mirror of https://github.com/0xERR0R/blocky.git
feat: support multiple hosts files
This commit is contained in:
parent
5e4c155793
commit
cfc3699ab5
|
@ -80,6 +80,7 @@ issues:
|
|||
# Exclude some linters from running on tests files.
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- gochecknoglobals
|
||||
- dupl
|
||||
- funlen
|
||||
- gochecknoglobals
|
||||
- gosec
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/0xERR0R/blocky/config/migration" //nolint:revive,stylecheck
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -10,32 +8,40 @@ import (
|
|||
|
||||
// BlockingConfig configuration for query blocking
|
||||
type BlockingConfig struct {
|
||||
BlackLists map[string][]string `yaml:"blackLists"`
|
||||
WhiteLists map[string][]string `yaml:"whiteLists"`
|
||||
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
|
||||
BlockType string `yaml:"blockType" default:"ZEROIP"`
|
||||
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
|
||||
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
|
||||
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
|
||||
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
|
||||
ProcessingConcurrency uint `yaml:"processingConcurrency" default:"4"`
|
||||
StartStrategy StartStrategyType `yaml:"startStrategy" default:"blocking"`
|
||||
MaxErrorsPerFile int `yaml:"maxErrorsPerFile" default:"5"`
|
||||
BlackLists map[string][]BytesSource `yaml:"blackLists"`
|
||||
WhiteLists map[string][]BytesSource `yaml:"whiteLists"`
|
||||
ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"`
|
||||
BlockType string `yaml:"blockType" default:"ZEROIP"`
|
||||
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
|
||||
Loading SourceLoadingConfig `yaml:"loading"`
|
||||
|
||||
// Deprecated options
|
||||
Deprecated struct {
|
||||
FailStartOnListError *bool `yaml:"failStartOnListError"`
|
||||
DownloadTimeout *Duration `yaml:"downloadTimeout"`
|
||||
DownloadAttempts *uint `yaml:"downloadAttempts"`
|
||||
DownloadCooldown *Duration `yaml:"downloadCooldown"`
|
||||
RefreshPeriod *Duration `yaml:"refreshPeriod"`
|
||||
FailStartOnListError *bool `yaml:"failStartOnListError"`
|
||||
ProcessingConcurrency *uint `yaml:"processingConcurrency"`
|
||||
StartStrategy *StartStrategyType `yaml:"startStrategy"`
|
||||
MaxErrorsPerFile *int `yaml:"maxErrorsPerFile"`
|
||||
} `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (c *BlockingConfig) migrate(logger *logrus.Entry) bool {
|
||||
return Migrate(logger, "blocking", c.Deprecated, map[string]Migrator{
|
||||
"failStartOnListError": Apply(To("startStrategy", c), func(oldValue bool) {
|
||||
if oldValue && c.StartStrategy != StartStrategyTypeFast {
|
||||
c.StartStrategy = StartStrategyTypeFailOnError
|
||||
"downloadTimeout": Move(To("loading.downloads.timeout", &c.Loading.Downloads)),
|
||||
"downloadAttempts": Move(To("loading.downloads.attempts", &c.Loading.Downloads)),
|
||||
"downloadCooldown": Move(To("loading.downloads.cooldown", &c.Loading.Downloads)),
|
||||
"refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)),
|
||||
"failStartOnListError": Apply(To("loading.strategy", &c.Loading), func(oldValue bool) {
|
||||
if oldValue {
|
||||
c.Loading.Strategy = StartStrategyTypeFailOnError
|
||||
}
|
||||
}),
|
||||
"processingConcurrency": Move(To("loading.concurrency", &c.Loading)),
|
||||
"startStrategy": Move(To("loading.strategy", &c.Loading)),
|
||||
"maxErrorsPerFile": Move(To("loading.maxErrorsPerSource", &c.Loading)),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -44,7 +50,7 @@ func (c *BlockingConfig) IsEnabled() bool {
|
|||
return len(c.ClientGroupsBlock) != 0
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Info("clientGroupsBlock:")
|
||||
|
||||
|
@ -58,17 +64,8 @@ func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
|
|||
logger.Infof("blockTTL = %s", c.BlockTTL)
|
||||
}
|
||||
|
||||
logger.Infof("downloadTimeout = %s", c.DownloadTimeout)
|
||||
|
||||
logger.Infof("startStrategy = %s", c.StartStrategy)
|
||||
|
||||
logger.Infof("maxErrorsPerFile = %d", c.MaxErrorsPerFile)
|
||||
|
||||
if c.RefreshPeriod > 0 {
|
||||
logger.Infof("refresh = every %s", c.RefreshPeriod)
|
||||
} else {
|
||||
logger.Debug("refresh = disabled")
|
||||
}
|
||||
logger.Info("loading:")
|
||||
log.WithIndent(logger, " ", c.Loading.LogConfig)
|
||||
|
||||
logger.Info("blacklist:")
|
||||
log.WithIndent(logger, " ", func(logger *logrus.Entry) {
|
||||
|
@ -81,18 +78,12 @@ func (c *BlockingConfig) LogConfig(logger *logrus.Entry) {
|
|||
})
|
||||
}
|
||||
|
||||
func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]string) {
|
||||
for group, links := range listGroups {
|
||||
func (c *BlockingConfig) logListGroups(logger *logrus.Entry, listGroups map[string][]BytesSource) {
|
||||
for group, sources := range listGroups {
|
||||
logger.Infof("%s:", group)
|
||||
|
||||
for _, link := range links {
|
||||
if idx := strings.IndexRune(link, '\n'); idx != -1 && idx < len(link) { // found and not last char
|
||||
link = link[:idx] // first line only
|
||||
|
||||
logger.Infof(" - %s [...]", link)
|
||||
} else {
|
||||
logger.Infof(" - %s", link)
|
||||
}
|
||||
for _, source := range sources {
|
||||
logger.Infof(" - %s", source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var _ = Describe("BlockingConfig", func() {
|
||||
|
@ -18,13 +17,12 @@ var _ = Describe("BlockingConfig", func() {
|
|||
cfg = BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: Duration(time.Minute),
|
||||
BlackLists: map[string][]string{
|
||||
"gr1": {"/a/file/path"},
|
||||
BlackLists: map[string][]BytesSource{
|
||||
"gr1": NewBytesSources("/a/file/path"),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
},
|
||||
RefreshPeriod: Duration(time.Hour),
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -59,26 +57,7 @@ var _ = Describe("BlockingConfig", func() {
|
|||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:"))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
|
||||
})
|
||||
When("refresh is disabled", func() {
|
||||
It("should reflect that", func() {
|
||||
cfg.RefreshPeriod = Duration(-1)
|
||||
|
||||
logger.Logger.Level = logrus.InfoLevel
|
||||
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled")))
|
||||
|
||||
logger.Logger.Level = logrus.TraceLevel
|
||||
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled")))
|
||||
})
|
||||
Expect(hook.Messages).Should(ContainElement(Equal("blockType = ZEROIP")))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names --values
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxTextSourceDisplayLen = 12
|
||||
|
||||
// var BytesSourceNone = BytesSource{}
|
||||
|
||||
// BytesSourceType supported BytesSource types. ENUM(
|
||||
// text=1 // Inline YAML block.
|
||||
// http // HTTP(S).
|
||||
// file // Local file.
|
||||
// )
|
||||
type BytesSourceType uint16
|
||||
|
||||
type BytesSource struct {
|
||||
Type BytesSourceType
|
||||
From string
|
||||
}
|
||||
|
||||
func (s BytesSource) String() string {
|
||||
switch s.Type {
|
||||
case BytesSourceTypeText:
|
||||
break
|
||||
|
||||
case BytesSourceTypeHttp:
|
||||
return s.From
|
||||
|
||||
case BytesSourceTypeFile:
|
||||
return fmt.Sprintf("file://%s", s.From)
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("unknown source (%s: %s)", s.Type, s.From)
|
||||
}
|
||||
|
||||
text := s.From
|
||||
truncated := false
|
||||
|
||||
if idx := strings.IndexRune(text, '\n'); idx != -1 {
|
||||
text = text[:idx] // first line only
|
||||
truncated = idx < len(text) // don't count removing last char
|
||||
}
|
||||
|
||||
if len(text) > maxTextSourceDisplayLen { // truncate
|
||||
text = text[:maxTextSourceDisplayLen]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
if truncated {
|
||||
return fmt.Sprintf("%s...", text[:maxTextSourceDisplayLen])
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// UnmarshalText implements `encoding.TextUnmarshaler`.
|
||||
func (s *BytesSource) UnmarshalText(data []byte) error {
|
||||
source := string(data)
|
||||
|
||||
switch {
|
||||
// Inline definition in YAML (with literal style Block Scalar)
|
||||
case strings.ContainsAny(source, "\n"):
|
||||
*s = BytesSource{Type: BytesSourceTypeText, From: source}
|
||||
|
||||
// HTTP(S)
|
||||
case strings.HasPrefix(source, "http"):
|
||||
*s = BytesSource{Type: BytesSourceTypeHttp, From: source}
|
||||
|
||||
// Probably path to a local file
|
||||
default:
|
||||
*s = BytesSource{Type: BytesSourceTypeFile, From: strings.TrimPrefix(source, "file://")}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newBytesSource(source string) BytesSource {
|
||||
var res BytesSource
|
||||
|
||||
// UnmarshalText never returns an error
|
||||
_ = res.UnmarshalText([]byte(source))
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func NewBytesSources(sources ...string) []BytesSource {
|
||||
res := make([]BytesSource, 0, len(sources))
|
||||
|
||||
for _, source := range sources {
|
||||
res = append(res, newBytesSource(source))
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func TextBytesSource(lines ...string) BytesSource {
|
||||
return BytesSource{Type: BytesSourceTypeText, From: inlineList(lines...)}
|
||||
}
|
||||
|
||||
func inlineList(lines ...string) string {
|
||||
res := strings.Join(lines, "\n")
|
||||
|
||||
// ensure at least one line ending so it's parsed as an inline block
|
||||
res += "\n"
|
||||
|
||||
return res
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
// Code generated by go-enum DO NOT EDIT.
|
||||
// Version:
|
||||
// Revision:
|
||||
// Build Date:
|
||||
// Built By:
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// BytesSourceTypeText is a BytesSourceType of type Text.
|
||||
// Inline YAML block.
|
||||
BytesSourceTypeText BytesSourceType = iota + 1
|
||||
// BytesSourceTypeHttp is a BytesSourceType of type Http.
|
||||
// HTTP(S).
|
||||
BytesSourceTypeHttp
|
||||
// BytesSourceTypeFile is a BytesSourceType of type File.
|
||||
// Local file.
|
||||
BytesSourceTypeFile
|
||||
)
|
||||
|
||||
var ErrInvalidBytesSourceType = fmt.Errorf("not a valid BytesSourceType, try [%s]", strings.Join(_BytesSourceTypeNames, ", "))
|
||||
|
||||
const _BytesSourceTypeName = "texthttpfile"
|
||||
|
||||
var _BytesSourceTypeNames = []string{
|
||||
_BytesSourceTypeName[0:4],
|
||||
_BytesSourceTypeName[4:8],
|
||||
_BytesSourceTypeName[8:12],
|
||||
}
|
||||
|
||||
// BytesSourceTypeNames returns a list of possible string values of BytesSourceType.
|
||||
func BytesSourceTypeNames() []string {
|
||||
tmp := make([]string, len(_BytesSourceTypeNames))
|
||||
copy(tmp, _BytesSourceTypeNames)
|
||||
return tmp
|
||||
}
|
||||
|
||||
// BytesSourceTypeValues returns a list of the values for BytesSourceType
|
||||
func BytesSourceTypeValues() []BytesSourceType {
|
||||
return []BytesSourceType{
|
||||
BytesSourceTypeText,
|
||||
BytesSourceTypeHttp,
|
||||
BytesSourceTypeFile,
|
||||
}
|
||||
}
|
||||
|
||||
var _BytesSourceTypeMap = map[BytesSourceType]string{
|
||||
BytesSourceTypeText: _BytesSourceTypeName[0:4],
|
||||
BytesSourceTypeHttp: _BytesSourceTypeName[4:8],
|
||||
BytesSourceTypeFile: _BytesSourceTypeName[8:12],
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (x BytesSourceType) String() string {
|
||||
if str, ok := _BytesSourceTypeMap[x]; ok {
|
||||
return str
|
||||
}
|
||||
return fmt.Sprintf("BytesSourceType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x BytesSourceType) IsValid() bool {
|
||||
_, ok := _BytesSourceTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _BytesSourceTypeValue = map[string]BytesSourceType{
|
||||
_BytesSourceTypeName[0:4]: BytesSourceTypeText,
|
||||
_BytesSourceTypeName[4:8]: BytesSourceTypeHttp,
|
||||
_BytesSourceTypeName[8:12]: BytesSourceTypeFile,
|
||||
}
|
||||
|
||||
// ParseBytesSourceType attempts to convert a string to a BytesSourceType.
|
||||
func ParseBytesSourceType(name string) (BytesSourceType, error) {
|
||||
if x, ok := _BytesSourceTypeValue[name]; ok {
|
||||
return x, nil
|
||||
}
|
||||
return BytesSourceType(0), fmt.Errorf("%s is %w", name, ErrInvalidBytesSourceType)
|
||||
}
|
||||
|
||||
// MarshalText implements the text marshaller method.
|
||||
func (x BytesSourceType) MarshalText() ([]byte, error) {
|
||||
return []byte(x.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the text unmarshaller method.
|
||||
func (x *BytesSourceType) UnmarshalText(text []byte) error {
|
||||
name := string(text)
|
||||
tmp, err := ParseBytesSourceType(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*x = tmp
|
||||
return nil
|
||||
}
|
|
@ -60,4 +60,20 @@ var _ = Describe("CachingConfig", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("EnablePrefetch", func() {
|
||||
When("prefetching is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
cfg = CachingConfig{}
|
||||
})
|
||||
|
||||
It("should return configuration", func() {
|
||||
cfg.EnablePrefetch()
|
||||
|
||||
Expect(cfg.Prefetching).Should(BeTrue())
|
||||
Expect(cfg.PrefetchThreshold).Should(Equal(0))
|
||||
Expect(cfg.MaxCachingTime).ShouldNot(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
115
config/config.go
115
config/config.go
|
@ -2,6 +2,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -32,7 +34,7 @@ type Configurable interface {
|
|||
|
||||
// LogConfig logs the receiver's configuration.
|
||||
//
|
||||
// Calling this method when `IsEnabled` returns false is undefined.
|
||||
// The behavior of this method is undefined when `IsEnabled` returns false.
|
||||
LogConfig(*logrus.Entry)
|
||||
}
|
||||
|
||||
|
@ -93,6 +95,30 @@ type QueryLogType int16
|
|||
// )
|
||||
type StartStrategyType uint16
|
||||
|
||||
func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
|
||||
if *s == StartStrategyTypeFast {
|
||||
go func() {
|
||||
err := setup()
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := setup()
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
|
||||
if *s == StartStrategyTypeFailOnError {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryLogField data field to be logged
|
||||
// ENUM(clientIP,clientName,responseReason,responseAnswer,question,duration)
|
||||
type QueryLogField string
|
||||
|
@ -259,6 +285,86 @@ func (c *toEnable) LogConfig(logger *logrus.Entry) {
|
|||
logger.Info("enabled")
|
||||
}
|
||||
|
||||
type SourceLoadingConfig struct {
|
||||
Concurrency uint `yaml:"concurrency" default:"4"`
|
||||
MaxErrorsPerSource int `yaml:"maxErrorsPerSource" default:"5"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
|
||||
Strategy StartStrategyType `yaml:"strategy" default:"blocking"`
|
||||
Downloads DownloaderConfig `yaml:"downloads"`
|
||||
}
|
||||
|
||||
func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("concurrency = %d", c.Concurrency)
|
||||
logger.Debugf("maxErrorsPerSource = %d", c.MaxErrorsPerSource)
|
||||
logger.Debugf("strategy = %s", c.Strategy)
|
||||
|
||||
if c.RefreshPeriod > 0 {
|
||||
logger.Infof("refresh = every %s", c.RefreshPeriod)
|
||||
} else {
|
||||
logger.Debug("refresh = disabled")
|
||||
}
|
||||
|
||||
logger.Info("downloads:")
|
||||
log.WithIndent(logger, " ", c.Downloads.LogConfig)
|
||||
}
|
||||
|
||||
func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context) error, logErr func(error)) error {
|
||||
refreshAndRecover := func(ctx context.Context) (rerr error) {
|
||||
defer func() {
|
||||
if val := recover(); val != nil {
|
||||
rerr = fmt.Errorf("refresh function panicked: %v", val)
|
||||
}
|
||||
}()
|
||||
|
||||
return refresh(ctx)
|
||||
}
|
||||
|
||||
err := c.Strategy.do(func() error { return refreshAndRecover(context.Background()) }, logErr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.RefreshPeriod > 0 {
|
||||
go c.periodically(refreshAndRecover, logErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SourceLoadingConfig) periodically(refresh func(context.Context) error, logErr func(error)) {
|
||||
ticker := time.NewTicker(c.RefreshPeriod.ToDuration())
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
err := refresh(context.Background())
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type DownloaderConfig struct {
|
||||
Timeout Duration `yaml:"timeout" default:"5s"`
|
||||
Attempts uint `yaml:"attempts" default:"3"`
|
||||
Cooldown Duration `yaml:"cooldown" default:"500ms"`
|
||||
}
|
||||
|
||||
func (c *DownloaderConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("timeout = %s", c.Timeout)
|
||||
logger.Infof("attempts = %d", c.Attempts)
|
||||
logger.Debugf("cooldown = %s", c.Cooldown)
|
||||
}
|
||||
|
||||
func WithDefaults[T any]() (T, error) {
|
||||
var cfg T
|
||||
|
||||
if err := defaults.Set(&cfg); err != nil {
|
||||
return cfg, fmt.Errorf("can't apply %T defaults: %w", cfg, err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var (
|
||||
config = &Config{}
|
||||
|
@ -270,9 +376,9 @@ func LoadConfig(path string, mandatory bool) (*Config, error) {
|
|||
cfgLock.Lock()
|
||||
defer cfgLock.Unlock()
|
||||
|
||||
cfg := Config{}
|
||||
if err := defaults.Set(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("can't apply default values: %w", err)
|
||||
cfg, err := WithDefaults[Config]()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fs, err := os.Stat(path)
|
||||
|
@ -398,6 +504,7 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
|
|||
})
|
||||
|
||||
usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts
|
||||
usesDepredOpts = cfg.HostsFile.migrate(logger) || usesDepredOpts
|
||||
|
||||
return usesDepredOpts
|
||||
}
|
||||
|
|
|
@ -63,6 +63,13 @@ func (x IPVersion) String() string {
|
|||
return fmt.Sprintf("IPVersion(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x IPVersion) IsValid() bool {
|
||||
_, ok := _IPVersionMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _IPVersionValue = map[string]IPVersion{
|
||||
_IPVersionName[0:4]: IPVersionDual,
|
||||
_IPVersionName[4:6]: IPVersionV4,
|
||||
|
@ -145,6 +152,13 @@ func (x NetProtocol) String() string {
|
|||
return fmt.Sprintf("NetProtocol(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x NetProtocol) IsValid() bool {
|
||||
_, ok := _NetProtocolMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _NetProtocolValue = map[string]NetProtocol{
|
||||
_NetProtocolName[0:7]: NetProtocolTcpUdp,
|
||||
_NetProtocolName[7:14]: NetProtocolTcpTls,
|
||||
|
@ -225,7 +239,8 @@ func (x QueryLogField) String() string {
|
|||
return string(x)
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x QueryLogField) IsValid() bool {
|
||||
_, err := ParseQueryLogField(string(x))
|
||||
return err == nil
|
||||
|
@ -333,6 +348,13 @@ func (x QueryLogType) String() string {
|
|||
return fmt.Sprintf("QueryLogType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x QueryLogType) IsValid() bool {
|
||||
_, ok := _QueryLogTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _QueryLogTypeValue = map[string]QueryLogType{
|
||||
_QueryLogTypeName[0:7]: QueryLogTypeConsole,
|
||||
_QueryLogTypeName[7:11]: QueryLogTypeNone,
|
||||
|
@ -418,6 +440,13 @@ func (x StartStrategyType) String() string {
|
|||
return fmt.Sprintf("StartStrategyType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x StartStrategyType) IsValid() bool {
|
||||
_, ok := _StartStrategyTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _StartStrategyTypeValue = map[string]StartStrategyType{
|
||||
_StartStrategyTypeName[0:8]: StartStrategyTypeBlocking,
|
||||
_StartStrategyTypeName[8:19]: StartStrategyTypeFailOnError,
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/creasty/defaults"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -48,15 +52,17 @@ var _ = Describe("Config", func() {
|
|||
BeforeEach(func() {
|
||||
c.Blocking.Deprecated.FailStartOnListError = ptrOf(true)
|
||||
})
|
||||
It("should change StartStrategy blocking to failOnError", func() {
|
||||
c.Blocking.StartStrategy = StartStrategyTypeBlocking
|
||||
It("should change loading.strategy blocking to failOnError", func() {
|
||||
c.Blocking.Loading.Strategy = StartStrategyTypeBlocking
|
||||
c.migrate(logger)
|
||||
Expect(c.Blocking.StartStrategy).Should(Equal(StartStrategyTypeFailOnError))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("blocking.loading.strategy")))
|
||||
Expect(c.Blocking.Loading.Strategy).Should(Equal(StartStrategyTypeFailOnError))
|
||||
})
|
||||
It("shouldn't change StartStrategy if set to fast", func() {
|
||||
c.Blocking.StartStrategy = StartStrategyTypeFast
|
||||
It("shouldn't change loading.strategy if set to fast", func() {
|
||||
c.Blocking.Loading.Strategy = StartStrategyTypeFast
|
||||
c.migrate(logger)
|
||||
Expect(c.Blocking.StartStrategy).Should(Equal(StartStrategyTypeFast))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("blocking.loading.strategy")))
|
||||
Expect(c.Blocking.Loading.Strategy).Should(Equal(StartStrategyTypeFast))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -206,8 +212,10 @@ var _ = Describe("Config", func() {
|
|||
When("duration is in wrong format", func() {
|
||||
It("should return error", func() {
|
||||
cfg := Config{}
|
||||
data := `blocking:
|
||||
refreshPeriod: wrongduration`
|
||||
data := `
|
||||
blocking:
|
||||
loading:
|
||||
refreshPeriod: wrongduration`
|
||||
err := unmarshalConfig([]byte(data), &cfg)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("invalid duration \"wrongduration\""))
|
||||
|
@ -534,6 +542,222 @@ bootstrapDns:
|
|||
"tcp-tls:[fd00::6cd4:d7e0:d99d:2952]",
|
||||
),
|
||||
)
|
||||
|
||||
Describe("SourceLoadingConfig", func() {
|
||||
var cfg SourceLoadingConfig
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = SourceLoadingConfig{
|
||||
Concurrency: 12,
|
||||
RefreshPeriod: Duration(time.Hour),
|
||||
}
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
It("should log configuration", func() {
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages[0]).Should(Equal("concurrency = 12"))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = every 1 hour")))
|
||||
})
|
||||
When("refresh is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
cfg.RefreshPeriod = Duration(-1)
|
||||
})
|
||||
|
||||
It("should reflect that", func() {
|
||||
logger.Logger.Level = logrus.InfoLevel
|
||||
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("refresh = disabled")))
|
||||
|
||||
logger.Logger.Level = logrus.TraceLevel
|
||||
|
||||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("refresh = disabled")))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("StartStrategyType", func() {
|
||||
Describe("StartStrategyTypeBlocking", func() {
|
||||
It("runs in the current goroutine", func() {
|
||||
sut := StartStrategyTypeBlocking
|
||||
panicVal := new(int)
|
||||
|
||||
defer func() {
|
||||
// recover will catch the panic if it happened in the same goroutine
|
||||
Expect(recover()).Should(BeIdenticalTo(panicVal))
|
||||
}()
|
||||
|
||||
_ = sut.do(func() error {
|
||||
panic(panicVal)
|
||||
}, nil)
|
||||
|
||||
Fail("unreachable")
|
||||
})
|
||||
|
||||
It("logs errors and doesn't return them", func() {
|
||||
sut := StartStrategyTypeBlocking
|
||||
expectedErr := errors.New("test")
|
||||
|
||||
err := sut.do(func() error {
|
||||
return expectedErr
|
||||
}, func(err error) {
|
||||
Expect(err).Should(MatchError(expectedErr))
|
||||
})
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("StartStrategyTypeFailOnError", func() {
|
||||
It("runs in the current goroutine", func() {
|
||||
sut := StartStrategyTypeBlocking
|
||||
panicVal := new(int)
|
||||
|
||||
defer func() {
|
||||
// recover will catch the panic if it happened in the same goroutine
|
||||
Expect(recover()).Should(BeIdenticalTo(panicVal))
|
||||
}()
|
||||
|
||||
_ = sut.do(func() error {
|
||||
panic(panicVal)
|
||||
}, nil)
|
||||
|
||||
Fail("unreachable")
|
||||
})
|
||||
|
||||
It("logs errors and returns them", func() {
|
||||
sut := StartStrategyTypeFailOnError
|
||||
expectedErr := errors.New("test")
|
||||
|
||||
err := sut.do(func() error {
|
||||
return expectedErr
|
||||
}, func(err error) {
|
||||
Expect(err).Should(MatchError(expectedErr))
|
||||
})
|
||||
|
||||
Expect(err).Should(MatchError(expectedErr))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("StartStrategyTypeFast", func() {
|
||||
It("runs in a new goroutine", func() {
|
||||
sut := StartStrategyTypeFast
|
||||
events := make(chan string)
|
||||
wait := make(chan struct{})
|
||||
|
||||
err := sut.do(func() error {
|
||||
events <- "start"
|
||||
<-wait
|
||||
events <- "done"
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
|
||||
Eventually(events, "50ms").Should(Receive(Equal("start")))
|
||||
Expect(err).Should(Succeed())
|
||||
Consistently(events).ShouldNot(Receive())
|
||||
close(wait)
|
||||
Eventually(events, "50ms").Should(Receive(Equal("done")))
|
||||
})
|
||||
|
||||
It("logs errors", func() {
|
||||
sut := StartStrategyTypeFast
|
||||
expectedErr := errors.New("test")
|
||||
wait := make(chan struct{})
|
||||
|
||||
err := sut.do(func() error {
|
||||
return expectedErr
|
||||
}, func(err error) {
|
||||
Expect(err).Should(MatchError(expectedErr))
|
||||
close(wait)
|
||||
})
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Eventually(wait, "50ms").Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SourceLoadingConfig", func() {
|
||||
It("handles panics", func() {
|
||||
sut := SourceLoadingConfig{
|
||||
Strategy: StartStrategyTypeFailOnError,
|
||||
}
|
||||
|
||||
panicMsg := "panic value"
|
||||
|
||||
err := sut.StartPeriodicRefresh(func(context.Context) error {
|
||||
panic(panicMsg)
|
||||
}, func(err error) {
|
||||
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
|
||||
})
|
||||
|
||||
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
|
||||
})
|
||||
|
||||
It("periodically calls refresh", func() {
|
||||
sut := SourceLoadingConfig{
|
||||
Strategy: StartStrategyTypeFast,
|
||||
RefreshPeriod: Duration(5 * time.Millisecond),
|
||||
}
|
||||
|
||||
panicMsg := "panic value"
|
||||
calls := make(chan int32)
|
||||
|
||||
var call atomic.Int32
|
||||
|
||||
err := sut.StartPeriodicRefresh(func(context.Context) error {
|
||||
call := call.Add(1)
|
||||
calls <- call
|
||||
|
||||
if call == 3 {
|
||||
panic(panicMsg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func(err error) {
|
||||
defer GinkgoRecover()
|
||||
|
||||
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
|
||||
Expect(call.Load()).Should(Equal(int32(3)))
|
||||
})
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Eventually(calls, "50ms").Should(Receive(Equal(int32(1))))
|
||||
Eventually(calls, "50ms").Should(Receive(Equal(int32(2))))
|
||||
Eventually(calls, "50ms").Should(Receive(Equal(int32(3))))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("WithDefaults", func() {
|
||||
It("use valid defaults", func() {
|
||||
type T struct {
|
||||
X int `default:"1"`
|
||||
}
|
||||
|
||||
t, err := WithDefaults[T]()
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(t.X).Should(Equal(1))
|
||||
})
|
||||
|
||||
It("return an error if the tag is invalid", func() {
|
||||
type T struct {
|
||||
X struct{} `default:"fail"`
|
||||
}
|
||||
|
||||
_, err := WithDefaults[T]()
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func defaultTestFileConfig() {
|
||||
|
@ -557,7 +781,7 @@ func defaultTestFileConfig() {
|
|||
Expect(config.Blocking.WhiteLists).Should(HaveLen(1))
|
||||
Expect(config.Blocking.ClientGroupsBlock).Should(HaveLen(2))
|
||||
Expect(config.Blocking.BlockTTL).Should(Equal(Duration(time.Minute)))
|
||||
Expect(config.Blocking.RefreshPeriod).Should(Equal(Duration(2 * time.Hour)))
|
||||
Expect(config.Blocking.Loading.RefreshPeriod).Should(Equal(Duration(2 * time.Hour)))
|
||||
Expect(config.Filtering.QueryTypes).Should(HaveLen(2))
|
||||
Expect(config.FqdnOnly.Enable).Should(BeTrue())
|
||||
|
||||
|
@ -613,7 +837,8 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
|
|||
" Laptop-D.fritz.box:",
|
||||
" - ads",
|
||||
" blockTTL: 1m",
|
||||
" refreshPeriod: 120",
|
||||
" loading:",
|
||||
" refreshPeriod: 120",
|
||||
"clientLookup:",
|
||||
" upstream: 192.168.178.1",
|
||||
" singleNameOrder:",
|
||||
|
@ -629,7 +854,6 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
|
|||
"startVerifyUpstream: false")
|
||||
}
|
||||
|
||||
//nolint:funlen
|
||||
func writeConfigDir(tmpDir *helpertest.TmpFolder) error {
|
||||
f1 := tmpDir.CreateStringFile("config1.yaml",
|
||||
"upstream:",
|
||||
|
@ -675,7 +899,8 @@ func writeConfigDir(tmpDir *helpertest.TmpFolder) error {
|
|||
" Laptop-D.fritz.box:",
|
||||
" - ads",
|
||||
" blockTTL: 1m",
|
||||
" refreshPeriod: 120",
|
||||
" loading:",
|
||||
" refreshPeriod: 120",
|
||||
"clientLookup:",
|
||||
" upstream: 192.168.178.1",
|
||||
" singleNameOrder:",
|
||||
|
|
|
@ -1,25 +1,49 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
. "github.com/0xERR0R/blocky/config/migration" //nolint:revive,stylecheck
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type HostsFileConfig struct {
|
||||
Filepath string `yaml:"filePath"`
|
||||
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"1h"`
|
||||
FilterLoopback bool `yaml:"filterLoopback"`
|
||||
Sources []BytesSource `yaml:"sources"`
|
||||
HostsTTL Duration `yaml:"hostsTTL" default:"1h"`
|
||||
FilterLoopback bool `yaml:"filterLoopback"`
|
||||
Loading SourceLoadingConfig `yaml:"loading"`
|
||||
|
||||
// Deprecated options
|
||||
Deprecated struct {
|
||||
RefreshPeriod *Duration `yaml:"refreshPeriod"`
|
||||
Filepath *BytesSource `yaml:"filePath"`
|
||||
} `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (c *HostsFileConfig) migrate(logger *logrus.Entry) bool {
|
||||
return Migrate(logger, "hostsFile", c.Deprecated, map[string]Migrator{
|
||||
"refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)),
|
||||
"filePath": Apply(To("sources", c), func(value BytesSource) {
|
||||
c.Sources = append(c.Sources, value)
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *HostsFileConfig) IsEnabled() bool {
|
||||
return len(c.Filepath) != 0
|
||||
return len(c.Sources) != 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (c *HostsFileConfig) LogConfig(logger *logrus.Entry) {
|
||||
logger.Infof("file path: %s", c.Filepath)
|
||||
logger.Infof("TTL: %s", c.HostsTTL)
|
||||
logger.Infof("refresh period: %s", c.RefreshPeriod)
|
||||
logger.Infof("filter loopback addresses: %t", c.FilterLoopback)
|
||||
|
||||
logger.Info("loading:")
|
||||
log.WithIndent(logger, " ", c.Loading.LogConfig)
|
||||
|
||||
logger.Info("sources:")
|
||||
|
||||
for _, source := range c.Sources {
|
||||
logger.Infof(" - %s", source)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,9 +15,12 @@ var _ = Describe("HostsFileConfig", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
cfg = HostsFileConfig{
|
||||
Filepath: "/dev/null",
|
||||
Sources: append(
|
||||
NewBytesSources("/a/file/path"),
|
||||
TextBytesSource("127.0.0.1 localhost"),
|
||||
),
|
||||
HostsTTL: Duration(29 * time.Minute),
|
||||
RefreshPeriod: Duration(30 * time.Minute),
|
||||
Loading: SourceLoadingConfig{RefreshPeriod: Duration(30 * time.Minute)},
|
||||
FilterLoopback: true,
|
||||
}
|
||||
})
|
||||
|
@ -50,7 +53,28 @@ var _ = Describe("HostsFileConfig", func() {
|
|||
cfg.LogConfig(logger)
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("file path: /dev/null")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("- file:///a/file/path")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("- 127.0.0.1 lo...")))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("migrate", func() {
|
||||
It("should", func() {
|
||||
cfg, err := WithDefaults[HostsFileConfig]()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
cfg.Deprecated.Filepath = ptrOf(newBytesSource("/a/file/path"))
|
||||
cfg.Deprecated.RefreshPeriod = ptrOf(Duration(time.Hour))
|
||||
|
||||
migrated := cfg.migrate(logger)
|
||||
Expect(migrated).Should(BeTrue())
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("hostsFile.loading.refreshPeriod")))
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("hostsFile.sources")))
|
||||
|
||||
Expect(cfg.Sources).Should(Equal([]BytesSource{*cfg.Deprecated.Filepath}))
|
||||
Expect(cfg.Loading.RefreshPeriod).Should(Equal(*cfg.Deprecated.RefreshPeriod))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# REVIEW: manual changelog entry
|
||||
|
||||
upstream:
|
||||
# these external DNS resolvers will be used. Blocky picks 2 random resolvers from the list for each query
|
||||
# format for resolver: [net:]host:[port][/path]. net could be empty (default, shortcut for tcp+udp), tcp+udp, tcp, udp, tcp-tls or https (DoH). If port is empty, default port will be used (53 for udp and tcp, 853 for tcp-tls, 443 for https (Doh))
|
||||
|
@ -99,22 +101,33 @@ blocking:
|
|||
# optional: TTL for answers to blocked domains
|
||||
# default: 6h
|
||||
blockTTL: 1m
|
||||
# optional: automatically list refresh period (in duration format). Default: 4h.
|
||||
# Negative value -> deactivate automatically refresh.
|
||||
# 0 value -> use default
|
||||
refreshPeriod: 4h
|
||||
# optional: timeout for list download (each url). Default: 60s. Use large values for big lists or slow internet connections
|
||||
downloadTimeout: 4m
|
||||
# optional: Download attempt timeout. Default: 60s
|
||||
downloadAttempts: 5
|
||||
# optional: Time between the download attempts. Default: 1s
|
||||
downloadCooldown: 10s
|
||||
# optional: if failOnError, application startup will fail if at least one list can't be downloaded / opened. Default: blocking
|
||||
startStrategy: failOnError
|
||||
# Number of errors allowed in a list before it is considered invalid.
|
||||
# A value of -1 disables the limit.
|
||||
# Default: 5
|
||||
maxErrorsPerFile: 5
|
||||
# optional: Configure how lists, AKA sources, are loaded
|
||||
loading:
|
||||
# optional: list refresh period in duration format.
|
||||
# Set to a value <= 0 to disable.
|
||||
# default: 4h
|
||||
refreshPeriod: 24h
|
||||
# optional: Applies only to lists that are downloaded (HTTP URLs).
|
||||
downloads:
|
||||
# optional: timeout for list download (each url). Use large values for big lists or slow internet connections
|
||||
# default: 5s
|
||||
timeout: 60s
|
||||
# optional: Maximum download attempts
|
||||
# default: 3
|
||||
attempts: 5
|
||||
# optional: Time between the download attempts
|
||||
# default: 500ms
|
||||
cooldown: 10s
|
||||
# optional: Maximum number of lists to process in parallel.
|
||||
# default: 4
|
||||
concurrency: 16
|
||||
# optional: if failOnError, application startup will fail if at least one list can't be downloaded/opened
|
||||
# default: blocking
|
||||
strategy: failOnError
|
||||
# Number of errors allowed in a list before it is considered invalid.
|
||||
# A value of -1 disables the limit.
|
||||
# default: 5
|
||||
maxErrorsPerSource: 5
|
||||
|
||||
# optional: configuration for caching of DNS responses
|
||||
caching:
|
||||
|
@ -161,6 +174,7 @@ clientLookup:
|
|||
clients:
|
||||
laptop:
|
||||
- 192.168.178.29
|
||||
|
||||
# optional: configuration for prometheus metrics endpoint
|
||||
prometheus:
|
||||
# enabled if true
|
||||
|
@ -214,6 +228,7 @@ redis:
|
|||
|
||||
# optional: Mininal TLS version that the DoH and DoT server will use
|
||||
minTlsServeVersion: 1.3
|
||||
|
||||
# if https port > 0: path to cert and key file for SSL encryption. if not set, self-signed certificate will be generated
|
||||
#certFile: server.crt
|
||||
#keyFile: server.key
|
||||
|
@ -238,14 +253,45 @@ fqdnOnly:
|
|||
|
||||
# optional: if path defined, use this file for query resolution (A, AAAA and rDNS). Default: empty
|
||||
hostsFile:
|
||||
# optional: Path to hosts file (e.g. /etc/hosts on Linux)
|
||||
filePath: /etc/hosts
|
||||
# optional: Hosts files to parse
|
||||
sources:
|
||||
- /etc/hosts
|
||||
- https://example.com/hosts
|
||||
- |
|
||||
# inline hosts
|
||||
127.0.0.1 example.com
|
||||
# optional: TTL, default: 1h
|
||||
hostsTTL: 60m
|
||||
# optional: Time between hosts file refresh, default: 1h
|
||||
refreshPeriod: 30m
|
||||
# optional: Whether loopback hosts addresses (127.0.0.0/8 and ::1) should be filtered or not, default: false
|
||||
hostsTTL: 30m
|
||||
# optional: Whether loopback hosts addresses (127.0.0.0/8 and ::1) should be filtered or not
|
||||
# default: false
|
||||
filterLoopback: true
|
||||
# optional: Configure how sources are loaded
|
||||
loading:
|
||||
# optional: file refresh period in duration format.
|
||||
# Set to a value <= 0 to disable.
|
||||
# default: 4h
|
||||
refreshPeriod: 24h
|
||||
# optional: Applies only to files that are downloaded (HTTP URLs).
|
||||
downloads:
|
||||
# optional: timeout for file download (each url). Use large values for big files or slow internet connections
|
||||
# default: 5s
|
||||
timeout: 60s
|
||||
# optional: Maximum download attempts
|
||||
# default: 3
|
||||
attempts: 5
|
||||
# optional: Time between the download attempts
|
||||
# default: 500ms
|
||||
cooldown: 10s
|
||||
# optional: Maximum number of files to process in parallel.
|
||||
# default: 4
|
||||
concurrency: 16
|
||||
# optional: if failOnError, application startup will fail if at least one file can't be downloaded/opened
|
||||
# default: blocking
|
||||
strategy: failOnError
|
||||
# Number of errors allowed in a file before it is considered invalid.
|
||||
# A value of -1 disables the limit.
|
||||
# default: 5
|
||||
maxErrorsPerSource: 5
|
||||
|
||||
# optional: ports configuration
|
||||
ports:
|
||||
|
@ -272,4 +318,4 @@ log:
|
|||
# optional: add EDE error codes to dns response
|
||||
ede:
|
||||
# enabled if true, Default: false
|
||||
enable: true
|
||||
enable: true
|
||||
|
|
|
@ -330,20 +330,24 @@ contains a map of client name and multiple IP addresses.
|
|||
|
||||
## Blocking and whitelisting
|
||||
|
||||
Blocky can download and use external lists with domains or IP addresses to block DNS query (e.g. advertisement, malware,
|
||||
Blocky can use lists of domains and IPs to block (e.g. advertisement, malware,
|
||||
trackers, adult sites). You can group several list sources together and define the blocking behavior per client.
|
||||
External blacklists must be either in the well-known [Hosts format](https://en.wikipedia.org/wiki/Hosts_(file)) or just
|
||||
a plain domain list (one domain per line). Blocky also supports regex as more powerful tool to define patterns to block.
|
||||
Blocking uses the [DNS sinkhole](https://en.wikipedia.org/wiki/DNS_sinkhole) approach. For each DNS query, the domain name from
|
||||
the request, IP address from the response, and any CNAME records will be checked to determine whether to block the query or not.
|
||||
|
||||
Blocky uses [DNS sinkhole](https://en.wikipedia.org/wiki/DNS_sinkhole) approach to block a DNS query. Domain name from
|
||||
the request, IP address from the response, and the CNAME record will be checked against configured blacklists.
|
||||
|
||||
To avoid over-blocking, you can define or use already existing whitelists.
|
||||
To avoid over-blocking, you can use whitelists.
|
||||
|
||||
### Definition black and whitelists
|
||||
|
||||
Each black or whitelist can be either a path to the local file, a URL to download or inline list definition of a domains
|
||||
in hosts format (YAML literal block scalar style). All Urls must be grouped to a group name.
|
||||
Lists are defined in groups. This allows using different sets of lists for different clients.
|
||||
|
||||
Each list in a group is a "source" and can be downloaded, read from a file, or inlined in the config. See [Sources](#sources) for details and configuring how those are loaded and reloaded/refreshed.
|
||||
|
||||
The supported list formats are:
|
||||
|
||||
1. the well-known [Hosts format](https://en.wikipedia.org/wiki/Hosts_(file))
|
||||
2. one domain per line (plain domain list)
|
||||
3. one regex per line
|
||||
|
||||
!!! example
|
||||
|
||||
|
@ -354,35 +358,38 @@ in hosts format (YAML literal block scalar style). All Urls must be grouped to a
|
|||
- https://s3.amazonaws.com/lists.disconnect.me/simple_ad.txt
|
||||
- https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts
|
||||
- |
|
||||
# inline definition with YAML literal block scalar style
|
||||
# inline definition using YAML literal block scalar style
|
||||
# content is in plain domain list format
|
||||
someadsdomain.com
|
||||
anotheradsdomain.com
|
||||
# this is a regex
|
||||
- |
|
||||
# inline definition with a regex
|
||||
/^banners?[_.-]/
|
||||
special:
|
||||
- https://raw.githubusercontent.com/StevenBlack/hosts/master/alternates/fakenews/hosts
|
||||
whiteLists:
|
||||
ads:
|
||||
- whitelist.txt
|
||||
- /path/to/file.txt
|
||||
- |
|
||||
# inline definition with YAML literal block scalar style
|
||||
whitelistdomain.com
|
||||
```
|
||||
|
||||
In this example you can see 2 groups: **ads** with 2 lists and **special** with one list. One local whitelist was defined for the **ads** group.
|
||||
In this example you can see 2 groups: **ads** and **special** with one list. The **ads** group includes 2 inline lists.
|
||||
|
||||
!!! warning
|
||||
|
||||
If the same group has black and whitelists, whitelists will be used to disable particular blacklist entries.
|
||||
If a group has **only** whitelist entries -> this means only domains from this list are allowed, all other domains will
|
||||
be blocked
|
||||
be blocked.
|
||||
|
||||
!!! note
|
||||
Please define also client group mapping, otherwise you black and whitelist definition will have no effect
|
||||
!!! warning
|
||||
You must also define client group mapping, otherwise you black and whitelist definition will have no effect.
|
||||
|
||||
#### Regex support
|
||||
|
||||
You can use regex to define patterns to block. A regex entry must start and end with the slash character (/). Some
|
||||
You can use regex to define patterns to block. A regex entry must start and end with the slash character (`/`). Some
|
||||
Examples:
|
||||
|
||||
- `/baddomain/` will block `www.baddomain.com`, `baddomain.com`, but also `mybaddomain-sometext.com`
|
||||
|
@ -395,7 +402,7 @@ In this configuration section, you can define, which blocking group(s) should be
|
|||
Example: All clients should use the **ads** group, which blocks advertisement and kids devices should use the **adult**
|
||||
group, which blocky adult sites.
|
||||
|
||||
Clients without a group assignment will use automatically the **default** group.
|
||||
Clients without an explicit group assignment will use the **default** group.
|
||||
|
||||
You can use the client name (see [Client name lookup](#client-name-lookup)), client's IP address, client's full-qualified domain name
|
||||
or a client subnet as CIDR notation.
|
||||
|
@ -460,82 +467,9 @@ after receiving the custom value.
|
|||
blockTTL: 10s
|
||||
```
|
||||
|
||||
### List refresh period
|
||||
### Lists Loading
|
||||
|
||||
To keep the list cache up-to-date, blocky will periodically download and reload all external lists. Default period is **
|
||||
4 hours**. You can configure this by setting the `blocking.refreshPeriod` parameter to a value in **duration format**.
|
||||
Negative value will deactivate automatically refresh.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
blocking:
|
||||
refreshPeriod: 60m
|
||||
```
|
||||
|
||||
Refresh every hour.
|
||||
|
||||
### Download
|
||||
|
||||
You can configure the list download attempts according to your internet connection:
|
||||
|
||||
| Parameter | Type | Mandatory | Default value | Description |
|
||||
|------------------|-----------------|-----------|---------------|------------------------------------------------|
|
||||
| downloadTimeout | duration format | no | 60s | Download attempt timeout |
|
||||
| downloadAttempts | int | no | 3 | How many download attempts should be performed |
|
||||
| downloadCooldown | duration format | no | 1s | Time between the download attempts |
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
blocking:
|
||||
downloadTimeout: 4m
|
||||
downloadAttempts: 5
|
||||
downloadCooldown: 10s
|
||||
```
|
||||
|
||||
### Start strategy
|
||||
|
||||
You can configure the blocking behavior during application start of blocky.
|
||||
If no strategy is selected blocking will be used.
|
||||
|
||||
| startStrategy | Description |
|
||||
|---------------|-------------------------------------------------------------------------------------------------------|
|
||||
| blocking | all blocking lists will be loaded before DNS resolution starts |
|
||||
| failOnError | like blocking but blocky will shut down if any download fails |
|
||||
| fast | DNS resolution starts immediately without blocking which will be enabled after list load is completed |
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
blocking:
|
||||
startStrategy: failOnError
|
||||
```
|
||||
|
||||
### Max Errors per file
|
||||
|
||||
Number of errors allowed in a list before it is considered invalid and parsing stops.
|
||||
A value of -1 disables the limit.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
blocking:
|
||||
maxErrorsPerFile: 10
|
||||
```
|
||||
|
||||
### Concurrency
|
||||
|
||||
Blocky downloads and processes links in a single group concurrently. With parameter `processingConcurrency` you can adjust
|
||||
how many links can be processed in the same time. Higher value can reduce the overall list refresh time, but more parallel
|
||||
download and processing jobs need more RAM. Please consider to reduce this value on systems with limited memory. Default value is 4.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
blocking:
|
||||
processingConcurrency: 10
|
||||
```
|
||||
See [Sources Loading](#sources-loading).
|
||||
|
||||
## Caching
|
||||
|
||||
|
@ -716,7 +650,7 @@ Configuration parameters:
|
|||
```yaml
|
||||
hostsFile:
|
||||
filePath: /etc/hosts
|
||||
hostsTTL: 60m
|
||||
hostsTTL: 1h
|
||||
refreshPeriod: 30m
|
||||
```
|
||||
|
||||
|
@ -745,3 +679,127 @@ for detailed information, how to create and configure SSL certificates.
|
|||
DoH url: `https://host:port/dns-query`
|
||||
|
||||
--8<-- "docs/includes/abbreviations.md"
|
||||
|
||||
## Sources
|
||||
|
||||
Sources are a concept shared by the blocking and hosts file resolvers. They represent where to load the files for each resolver.
|
||||
|
||||
The supported source types are:
|
||||
|
||||
- HTTP(S) URL (any source starting with `http`)
|
||||
- inline configuration (any source containing a newline)
|
||||
- local file path (any source not matching the above rules)
|
||||
|
||||
!!! note
|
||||
|
||||
The format/content of the sources depends on the context: lists and hosts files have different, but overlapping, supported formats.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
- https://example.com/a/source # blocky will download and parse the file
|
||||
- /a/file/path # blocky will read the local file
|
||||
- | # blocky will parse the content of this multi-line string
|
||||
# inline configuration
|
||||
```
|
||||
|
||||
### Sources Loading
|
||||
|
||||
This sections covers `loading` configuration that applies to both the blocking and hosts file resolvers.
|
||||
These settings apply only to the resolver under which they are nested.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
blocking:
|
||||
loading:
|
||||
# only applies to white/blacklists
|
||||
|
||||
hostsFile:
|
||||
loading:
|
||||
# only applies to hostsFile sources
|
||||
```
|
||||
|
||||
#### Refresh / Reload
|
||||
|
||||
To keep source contents up-to-date, blocky can periodically refresh and reparse them. Default period is **
|
||||
4 hours**. You can configure this by setting the `refreshPeriod` parameter to a value in **duration format**.
|
||||
A value of zero or less will disable this feature.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
loading:
|
||||
refreshPeriod: 1h
|
||||
```
|
||||
|
||||
Refresh every hour.
|
||||
|
||||
### Downloads
|
||||
|
||||
Configures how HTTP(S) sources are downloaded:
|
||||
|
||||
| Parameter | Type | Mandatory | Default value | Description |
|
||||
|-----------|----------|-----------|---------------|------------------------------------------------|
|
||||
| timeout | duration | no | 5s | Download attempt timeout |
|
||||
| attempts | int | no | 3 | How many download attempts should be performed |
|
||||
| cooldown | duration | no | 500ms | Time between the download attempts |
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
loading:
|
||||
downloads:
|
||||
timeout: 4m
|
||||
attempts: 5
|
||||
cooldown: 10s
|
||||
```
|
||||
|
||||
### Strategy
|
||||
|
||||
This configures how Blocky startup works.
|
||||
The default strategy is blocking.
|
||||
|
||||
| strategy | Description |
|
||||
|-------------|------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| blocking | all sources are loaded before DNS resolution starts |
|
||||
| failOnError | like blocking but blocky will shut down if any source fails to load |
|
||||
| fast | blocky starts serving DNS immediately and sources are loaded asynchronously. The features requiring the sources should enable soon after |
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
loading:
|
||||
strategy: failOnError
|
||||
```
|
||||
|
||||
### Max Errors per Source
|
||||
|
||||
Number of errors allowed when parsing a source before it is considered invalid and parsing stops.
|
||||
A value of -1 disables the limit.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
loading:
|
||||
maxErrorsPerSource: 10
|
||||
```
|
||||
|
||||
### Concurrency
|
||||
|
||||
Blocky downloads and processes sources concurrently. This allows limiting how many can be processed in the same time.
|
||||
Larger values can reduce the overall list refresh time at the cost of using more RAM. Please consider reducing this value on systems with limited memory.
|
||||
Default value is 4.
|
||||
|
||||
!!! example
|
||||
|
||||
```yaml
|
||||
loading:
|
||||
concurrency: 10
|
||||
```
|
||||
|
||||
!!! note
|
||||
|
||||
As with other settings under `loading`, the limit applies to the blocking and hosts file resolvers separately.
|
||||
The total number of concurrent sources concurrently processed can reach the sum of both values.
|
||||
For example if blocking has a limit set to 8 and hosts file's is 4, there could be up to 12 concurrent jobs.
|
||||
|
|
|
@ -19,7 +19,7 @@ var _ = Describe("External lists and query blocking", func() {
|
|||
})
|
||||
Describe("List download on startup", func() {
|
||||
When("external blacklist ist not available", func() {
|
||||
Context("startStrategy = blocking", func() {
|
||||
Context("loading.strategy = blocking", func() {
|
||||
BeforeEach(func() {
|
||||
blocky, err = createBlockyContainer(tmpDir,
|
||||
"log:",
|
||||
|
@ -28,7 +28,8 @@ var _ = Describe("External lists and query blocking", func() {
|
|||
" default:",
|
||||
" - moka",
|
||||
"blocking:",
|
||||
" startStrategy: blocking",
|
||||
" loading:",
|
||||
" strategy: blocking",
|
||||
" blackLists:",
|
||||
" ads:",
|
||||
" - http://wrong.domain.url/list.txt",
|
||||
|
@ -54,7 +55,7 @@ var _ = Describe("External lists and query blocking", func() {
|
|||
Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("cannot open source: ")))
|
||||
})
|
||||
})
|
||||
Context("startStrategy = failOnError", func() {
|
||||
Context("loading.strategy = failOnError", func() {
|
||||
BeforeEach(func() {
|
||||
blocky, err = createBlockyContainer(tmpDir,
|
||||
"log:",
|
||||
|
@ -63,7 +64,8 @@ var _ = Describe("External lists and query blocking", func() {
|
|||
" default:",
|
||||
" - moka",
|
||||
"blocking:",
|
||||
" startStrategy: failOnError",
|
||||
" loading:",
|
||||
" strategy: failOnError",
|
||||
" blackLists:",
|
||||
" ads:",
|
||||
" - http://wrong.domain.url/list.txt",
|
||||
|
|
|
@ -29,7 +29,7 @@ const (
|
|||
// CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count
|
||||
CachingDomainsToPrefetchCountChanged = "caching:domainsToPrefetchCountChanged"
|
||||
|
||||
// CachingFailedDownloadChanged fires, if a download of a blocking list fails
|
||||
// CachingFailedDownloadChanged fires, if a download of a blocking list or hosts file fails
|
||||
CachingFailedDownloadChanged = "caching:failedDownload"
|
||||
|
||||
// ApplicationStarted fires on start of the application. Parameter: version number, build time
|
||||
|
|
3
go.mod
3
go.mod
|
@ -37,6 +37,7 @@ require (
|
|||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5
|
||||
github.com/docker/go-connections v0.4.0
|
||||
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
|
||||
github.com/testcontainers/testcontainers-go v0.21.0
|
||||
|
@ -56,7 +57,7 @@ require (
|
|||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/go-logr/logr v1.2.4 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
github.com/google/pprof v0.0.0-20230309165930-d61513b1440d // indirect
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/klauspost/compress v1.11.13 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
|
|
8
go.sum
8
go.sum
|
@ -14,6 +14,8 @@ github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBa
|
|||
github.com/Microsoft/go-winio v0.5.2 h1:a9IhgEQBCUEk6QCdml9CiJGhAws+YwffDHEMp1VMrpA=
|
||||
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||
github.com/Microsoft/hcsshim v0.9.7 h1:mKNHW/Xvv1aFH87Jb6ERDzXTJTLPlmzfZ28VBFD/bfg=
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5 h1:3ubNg+3q/Y3lqxga0G90jste3i+HGDgrlPXK/feKUEI=
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230418015731-5c82efbe68f5/go.mod h1:hkcYs23P9zbezt09v8168B4lt69PGuoxRPQ6IJHKpHo=
|
||||
github.com/abice/go-enum v0.5.6 h1:Ury51IQXUppbIl56MqRU/++A8SSeLG4plePphPjxW1s=
|
||||
github.com/abice/go-enum v0.5.6/go.mod h1:X2GpCT8VkCXLkVm48hebWx3cVgFJ8zM5nY5iUrJZO1Q=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
|
||||
|
@ -108,8 +110,8 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/pprof v0.0.0-20230309165930-d61513b1440d h1:um9/pc7tKMINFfP1eE7Wv6PRGXlcCSJkVajF7KJw3uQ=
|
||||
github.com/google/pprof v0.0.0-20230309165930-d61513b1440d/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
|
@ -124,7 +126,6 @@ github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+l
|
|||
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
|
||||
github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
|
||||
github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
|
||||
github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM=
|
||||
github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
|
||||
|
@ -319,7 +320,6 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
|
|
@ -6,18 +6,12 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/evt"
|
||||
"github.com/avast/retry-go/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDownloadTimeout = time.Second
|
||||
defaultDownloadAttempts = uint(1)
|
||||
defaultDownloadCooldown = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// TransientError represents a temporary error like timeout, network errors...
|
||||
type TransientError struct {
|
||||
inner error
|
||||
|
@ -36,74 +30,35 @@ type FileDownloader interface {
|
|||
DownloadFile(link string) (io.ReadCloser, error)
|
||||
}
|
||||
|
||||
// HTTPDownloader downloads files via HTTP protocol
|
||||
type HTTPDownloader struct {
|
||||
downloadTimeout time.Duration
|
||||
downloadAttempts uint
|
||||
downloadCooldown time.Duration
|
||||
httpTransport *http.Transport
|
||||
// httpDownloader downloads files via HTTP protocol
|
||||
type httpDownloader struct {
|
||||
cfg config.DownloaderConfig
|
||||
|
||||
client http.Client
|
||||
}
|
||||
|
||||
type DownloaderOption func(c *HTTPDownloader)
|
||||
|
||||
func NewDownloader(options ...DownloaderOption) *HTTPDownloader {
|
||||
d := &HTTPDownloader{
|
||||
downloadTimeout: defaultDownloadTimeout,
|
||||
downloadAttempts: defaultDownloadAttempts,
|
||||
downloadCooldown: defaultDownloadCooldown,
|
||||
httpTransport: &http.Transport{},
|
||||
}
|
||||
|
||||
for _, opt := range options {
|
||||
opt(d)
|
||||
}
|
||||
|
||||
return d
|
||||
func NewDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) FileDownloader {
|
||||
return newDownloader(cfg, transport)
|
||||
}
|
||||
|
||||
// WithTimeout sets the download timeout
|
||||
func WithTimeout(timeout time.Duration) DownloaderOption {
|
||||
return func(d *HTTPDownloader) {
|
||||
d.downloadTimeout = timeout
|
||||
func newDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) *httpDownloader {
|
||||
return &httpDownloader{
|
||||
cfg: cfg,
|
||||
|
||||
client: http.Client{
|
||||
Transport: transport,
|
||||
Timeout: cfg.Timeout.ToDuration(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout sets the pause between 2 download attempts
|
||||
func WithCooldown(cooldown time.Duration) DownloaderOption {
|
||||
return func(d *HTTPDownloader) {
|
||||
d.downloadCooldown = cooldown
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout sets the attempt number for retry
|
||||
func WithAttempts(downloadAttempts uint) DownloaderOption {
|
||||
return func(d *HTTPDownloader) {
|
||||
d.downloadAttempts = downloadAttempts
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout sets the HTTP transport
|
||||
func WithTransport(httpTransport *http.Transport) DownloaderOption {
|
||||
return func(d *HTTPDownloader) {
|
||||
d.httpTransport = httpTransport
|
||||
}
|
||||
}
|
||||
|
||||
func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
||||
client := http.Client{
|
||||
Timeout: d.downloadTimeout,
|
||||
Transport: d.httpTransport,
|
||||
}
|
||||
|
||||
logger().WithField("link", link).Info("starting download")
|
||||
|
||||
func (d *httpDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
||||
var body io.ReadCloser
|
||||
|
||||
err := retry.Do(
|
||||
func() error {
|
||||
var resp *http.Response
|
||||
var httpErr error
|
||||
if resp, httpErr = client.Get(link); httpErr == nil {
|
||||
resp, httpErr := d.client.Get(link)
|
||||
if httpErr == nil {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
body = resp.Body
|
||||
|
||||
|
@ -121,17 +76,18 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
|||
|
||||
return httpErr
|
||||
},
|
||||
retry.Attempts(d.downloadAttempts),
|
||||
retry.Attempts(d.cfg.Attempts),
|
||||
retry.DelayType(retry.FixedDelay),
|
||||
retry.Delay(d.downloadCooldown),
|
||||
retry.Delay(d.cfg.Cooldown.ToDuration()),
|
||||
retry.LastErrorOnly(true),
|
||||
retry.OnRetry(func(n uint, err error) {
|
||||
var transientErr *TransientError
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
logger := logger().WithField("link", link).WithField("attempt",
|
||||
fmt.Sprintf("%d/%d", n+1, d.downloadAttempts))
|
||||
logger := logger().
|
||||
WithField("link", link).
|
||||
WithField("attempt", fmt.Sprintf("%d/%d", n+1, d.cfg.Attempts))
|
||||
|
||||
switch {
|
||||
case errors.As(err, &transientErr):
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/evt"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -20,11 +21,17 @@ import (
|
|||
|
||||
var _ = Describe("Downloader", func() {
|
||||
var (
|
||||
sut *HTTPDownloader
|
||||
sutConfig config.DownloaderConfig
|
||||
sut *httpDownloader
|
||||
failedDownloadCountEvtChannel chan string
|
||||
loggerHook *test.Hook
|
||||
)
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
|
||||
sutConfig, err = config.WithDefaults[config.DownloaderConfig]()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
failedDownloadCountEvtChannel = make(chan string, 5)
|
||||
// collect received events in the channel
|
||||
fn := func(url string) {
|
||||
|
@ -40,33 +47,27 @@ var _ = Describe("Downloader", func() {
|
|||
DeferCleanup(loggerHook.Reset)
|
||||
})
|
||||
|
||||
Describe("Construct downloader", func() {
|
||||
When("No options are provided", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader()
|
||||
})
|
||||
It("Should provide default valus", func() {
|
||||
Expect(sut.downloadAttempts).Should(BeNumerically("==", defaultDownloadAttempts))
|
||||
Expect(sut.downloadTimeout).Should(BeNumerically("==", defaultDownloadTimeout))
|
||||
Expect(sut.downloadCooldown).Should(BeNumerically("==", defaultDownloadCooldown))
|
||||
})
|
||||
})
|
||||
When("Options are provided", func() {
|
||||
JustBeforeEach(func() {
|
||||
sut = newDownloader(sutConfig, nil)
|
||||
})
|
||||
|
||||
Describe("NewDownloader", func() {
|
||||
It("Should use provided parameters", func() {
|
||||
transport := &http.Transport{}
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithAttempts(5),
|
||||
WithCooldown(2*time.Second),
|
||||
WithTimeout(5*time.Second),
|
||||
WithTransport(transport),
|
||||
)
|
||||
})
|
||||
It("Should use provided parameters", func() {
|
||||
Expect(sut.downloadAttempts).Should(BeNumerically("==", 5))
|
||||
Expect(sut.downloadTimeout).Should(BeNumerically("==", 5*time.Second))
|
||||
Expect(sut.downloadCooldown).Should(BeNumerically("==", 2*time.Second))
|
||||
Expect(sut.httpTransport).Should(BeIdenticalTo(transport))
|
||||
})
|
||||
|
||||
sut = NewDownloader(
|
||||
config.DownloaderConfig{
|
||||
Attempts: 5,
|
||||
Cooldown: config.Duration(2 * time.Second),
|
||||
Timeout: config.Duration(5 * time.Second),
|
||||
},
|
||||
transport,
|
||||
).(*httpDownloader)
|
||||
|
||||
Expect(sut.cfg.Attempts).Should(BeNumerically("==", 5))
|
||||
Expect(sut.cfg.Timeout).Should(BeNumerically("==", 5*time.Second))
|
||||
Expect(sut.cfg.Cooldown).Should(BeNumerically("==", 2*time.Second))
|
||||
Expect(sut.client.Transport).Should(BeIdenticalTo(transport))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -77,7 +78,7 @@ var _ = Describe("Downloader", func() {
|
|||
server = TestServer("line.one\nline.two")
|
||||
DeferCleanup(server.Close)
|
||||
|
||||
sut = NewDownloader()
|
||||
sut = newDownloader(sutConfig, nil)
|
||||
})
|
||||
It("Should return all lines from the file", func() {
|
||||
reader, err := sut.DownloadFile(server.URL)
|
||||
|
@ -98,7 +99,7 @@ var _ = Describe("Downloader", func() {
|
|||
}))
|
||||
DeferCleanup(server.Close)
|
||||
|
||||
sut = NewDownloader(WithAttempts(3))
|
||||
sutConfig.Attempts = 3
|
||||
})
|
||||
It("Should return error", func() {
|
||||
reader, err := sut.DownloadFile(server.URL)
|
||||
|
@ -112,7 +113,7 @@ var _ = Describe("Downloader", func() {
|
|||
})
|
||||
When("Wrong URL is defined", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader()
|
||||
sutConfig.Attempts = 1
|
||||
})
|
||||
It("Should return error", func() {
|
||||
_, err := sut.DownloadFile("somewrongurl")
|
||||
|
@ -129,10 +130,11 @@ var _ = Describe("Downloader", func() {
|
|||
var attempt uint64 = 1
|
||||
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithTimeout(20*time.Millisecond),
|
||||
WithAttempts(3),
|
||||
WithCooldown(time.Millisecond))
|
||||
sutConfig = config.DownloaderConfig{
|
||||
Timeout: config.Duration(20 * time.Millisecond),
|
||||
Attempts: 3,
|
||||
Cooldown: config.Duration(time.Millisecond),
|
||||
}
|
||||
|
||||
// should produce a timeout on first attempt
|
||||
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -166,24 +168,23 @@ var _ = Describe("Downloader", func() {
|
|||
})
|
||||
When("If timeout occurs on all request", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithTimeout(100*time.Millisecond),
|
||||
WithAttempts(3),
|
||||
WithCooldown(time.Millisecond))
|
||||
sutConfig = config.DownloaderConfig{
|
||||
Timeout: config.Duration(10 * time.Millisecond),
|
||||
Attempts: 3,
|
||||
Cooldown: config.Duration(time.Millisecond),
|
||||
}
|
||||
|
||||
// should always produce a timeout
|
||||
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}))
|
||||
DeferCleanup(server.Close)
|
||||
})
|
||||
It("Should perform a retry until max retry attempt count is reached and return TransientError", func() {
|
||||
reader, err := sut.DownloadFile(server.URL)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
||||
err2 := unwrapTransientErr(err)
|
||||
|
||||
Expect(err2.Error()).Should(ContainSubstring("Timeout"))
|
||||
Expect(errors.As(err, new(*TransientError))).Should(BeTrue())
|
||||
Expect(err.Error()).Should(ContainSubstring("Timeout"))
|
||||
Expect(reader).Should(BeNil())
|
||||
|
||||
// failed download event was emitted 3 times
|
||||
|
@ -193,19 +194,18 @@ var _ = Describe("Downloader", func() {
|
|||
})
|
||||
When("DNS resolution of passed URL fails", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithTimeout(500*time.Millisecond),
|
||||
WithAttempts(3),
|
||||
WithCooldown(200*time.Millisecond))
|
||||
sutConfig = config.DownloaderConfig{
|
||||
Timeout: config.Duration(500 * time.Millisecond),
|
||||
Attempts: 3,
|
||||
Cooldown: 200 * config.Duration(time.Millisecond),
|
||||
}
|
||||
})
|
||||
It("Should perform a retry until max retry attempt count is reached and return DNSError", func() {
|
||||
reader, err := sut.DownloadFile("http://some.domain.which.does.not.exist")
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
||||
err2 := unwrapTransientErr(err)
|
||||
|
||||
var dnsError *net.DNSError
|
||||
Expect(errors.As(err2, &dnsError)).To(BeTrue(), "received error %w", err)
|
||||
Expect(errors.As(err, &dnsError)).Should(BeTrue(), "received error %w", err)
|
||||
Expect(reader).Should(BeNil())
|
||||
|
||||
// failed download event was emitted 3 times
|
||||
|
@ -216,12 +216,3 @@ var _ = Describe("Downloader", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
func unwrapTransientErr(origErr error) error {
|
||||
var transientErr *TransientError
|
||||
if errors.As(origErr, &transientErr) {
|
||||
return transientErr.Unwrap()
|
||||
}
|
||||
|
||||
return origErr
|
||||
}
|
||||
|
|
|
@ -5,25 +5,20 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/0xERR0R/blocky/cache/stringcache"
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/evt"
|
||||
"github.com/0xERR0R/blocky/lists/parsers"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/ThinkChaos/parcour"
|
||||
"github.com/ThinkChaos/parcour/jobgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultProcessingConcurrency = 4
|
||||
chanCap = 1000
|
||||
)
|
||||
const groupProducersBufferCap = 1000
|
||||
|
||||
// ListCacheType represents the type of cached list ENUM(
|
||||
// blacklist // is a list with blocked domains
|
||||
|
@ -41,19 +36,17 @@ type Matcher interface {
|
|||
type ListCache struct {
|
||||
groupedCache stringcache.GroupedStringCache
|
||||
|
||||
groupToLinks map[string][]string
|
||||
refreshPeriod time.Duration
|
||||
downloader FileDownloader
|
||||
listType ListCacheType
|
||||
processingConcurrency uint
|
||||
maxErrorsPerFile int
|
||||
cfg config.SourceLoadingConfig
|
||||
listType ListCacheType
|
||||
groupSources map[string][]config.BytesSource
|
||||
downloader FileDownloader
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (b *ListCache) LogConfig(logger *logrus.Entry) {
|
||||
var total int
|
||||
|
||||
for group := range b.groupToLinks {
|
||||
for group := range b.groupSources {
|
||||
count := b.groupedCache.ElementCount(group)
|
||||
logger.Infof("%s: %d entries", group, count)
|
||||
total += count
|
||||
|
@ -63,132 +56,36 @@ func (b *ListCache) LogConfig(logger *logrus.Entry) {
|
|||
}
|
||||
|
||||
// NewListCache creates new list instance
|
||||
func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeriod time.Duration,
|
||||
downloader FileDownloader, processingConcurrency uint, async bool, maxErrorsPerFile int,
|
||||
func NewListCache(
|
||||
t ListCacheType, cfg config.SourceLoadingConfig,
|
||||
groupSources map[string][]config.BytesSource, downloader FileDownloader,
|
||||
) (*ListCache, error) {
|
||||
if processingConcurrency == 0 {
|
||||
processingConcurrency = defaultProcessingConcurrency
|
||||
}
|
||||
|
||||
b := &ListCache{
|
||||
c := &ListCache{
|
||||
groupedCache: stringcache.NewChainedGroupedCache(
|
||||
stringcache.NewInMemoryGroupedStringCache(),
|
||||
stringcache.NewInMemoryGroupedRegexCache(),
|
||||
),
|
||||
groupToLinks: groupToLinks,
|
||||
refreshPeriod: refreshPeriod,
|
||||
downloader: downloader,
|
||||
listType: t,
|
||||
processingConcurrency: processingConcurrency,
|
||||
maxErrorsPerFile: maxErrorsPerFile,
|
||||
|
||||
cfg: cfg,
|
||||
listType: t,
|
||||
groupSources: groupSources,
|
||||
downloader: downloader,
|
||||
}
|
||||
|
||||
var initError error
|
||||
if async {
|
||||
initError = nil
|
||||
|
||||
// start list refresh in the background
|
||||
go b.Refresh()
|
||||
} else {
|
||||
initError = b.refresh(true)
|
||||
err := cfg.StartPeriodicRefresh(c.refresh, func(err error) {
|
||||
logger().WithError(err).Errorf("could not init %s", t)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initError == nil {
|
||||
go periodicUpdate(b)
|
||||
}
|
||||
|
||||
return b, initError
|
||||
}
|
||||
|
||||
// periodicUpdate triggers periodical refresh (and download) of list entries
|
||||
func periodicUpdate(cache *ListCache) {
|
||||
if cache.refreshPeriod > 0 {
|
||||
ticker := time.NewTicker(cache.refreshPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
cache.Refresh()
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func logger() *logrus.Entry {
|
||||
return log.PrefixedLog("list_cache")
|
||||
}
|
||||
|
||||
// downloads and reads files with domain names and creates cache for them
|
||||
//
|
||||
//nolint:funlen // will refactor in a later commit
|
||||
func (b *ListCache) createCacheForGroup(group string, links []string) (created bool, err error) {
|
||||
groupFactory := b.groupedCache.Refresh(group)
|
||||
|
||||
fileLinesChan := make(chan string, chanCap)
|
||||
errChan := make(chan error, chanCap)
|
||||
|
||||
workerDoneChan := make(chan bool, len(links))
|
||||
|
||||
// guard channel is used to limit the number of concurrent executions of the function
|
||||
guard := make(chan struct{}, b.processingConcurrency)
|
||||
|
||||
processingLinkJobs := len(links)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// loop over links (http/local) or inline definitions
|
||||
// start a new goroutine for each link, but limit to max. number (see processingConcurrency)
|
||||
for idx, link := range links {
|
||||
go func(idx int, link string) {
|
||||
// try to write in this channel -> this will block if max amount of goroutines are being executed
|
||||
guard <- struct{}{}
|
||||
|
||||
defer func() {
|
||||
// remove from guard channel to allow other blocked goroutines to continue
|
||||
<-guard
|
||||
workerDoneChan <- true
|
||||
}()
|
||||
|
||||
name := linkName(idx, link)
|
||||
|
||||
err := b.parseFile(ctx, name, link, fileLinesChan)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}(idx, link)
|
||||
}
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case line := <-fileLinesChan:
|
||||
groupFactory.AddEntry(line)
|
||||
case e := <-errChan:
|
||||
var transientErr *TransientError
|
||||
|
||||
if errors.As(e, &transientErr) {
|
||||
return false, e
|
||||
}
|
||||
err = multierror.Append(err, e)
|
||||
case <-workerDoneChan:
|
||||
processingLinkJobs--
|
||||
|
||||
default:
|
||||
if processingLinkJobs == 0 {
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if groupFactory.Count() == 0 && err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
groupFactory.Finish()
|
||||
|
||||
return true, err
|
||||
}
|
||||
|
||||
// Match matches passed domain name against cached list entries
|
||||
func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []string) {
|
||||
return b.groupedCache.Contains(domain, groupsToCheck)
|
||||
|
@ -196,65 +93,123 @@ func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []strin
|
|||
|
||||
// Refresh triggers the refresh of a list
|
||||
func (b *ListCache) Refresh() {
|
||||
_ = b.refresh(false)
|
||||
_ = b.refresh(context.Background())
|
||||
}
|
||||
|
||||
func (b *ListCache) refresh(isInit bool) error {
|
||||
var err error
|
||||
func (b *ListCache) refresh(ctx context.Context) error {
|
||||
unlimitedGrp, _ := jobgroup.WithContext(ctx)
|
||||
defer unlimitedGrp.Close()
|
||||
|
||||
for group, links := range b.groupToLinks {
|
||||
created, e := b.createCacheForGroup(group, links)
|
||||
if e != nil {
|
||||
err = multierror.Append(err, multierror.Prefix(e, fmt.Sprintf("can't create cache group '%s':", group)))
|
||||
}
|
||||
producersGrp := jobgroup.WithMaxConcurrency(unlimitedGrp, b.cfg.Concurrency)
|
||||
defer producersGrp.Close()
|
||||
|
||||
count := b.groupedCache.ElementCount(group)
|
||||
for group, sources := range b.groupSources {
|
||||
group, sources := group, sources
|
||||
|
||||
if !created {
|
||||
logger := logger().WithFields(logrus.Fields{
|
||||
"group": group,
|
||||
"total_count": count,
|
||||
})
|
||||
unlimitedGrp.Go(func(ctx context.Context) error {
|
||||
err := b.createCacheForGroup(producersGrp, unlimitedGrp, group, sources)
|
||||
if err != nil {
|
||||
count := b.groupedCache.ElementCount(group)
|
||||
|
||||
if count == 0 || isInit {
|
||||
logger.Warn("Populating of group cache failed, cache will be empty until refresh succeeds")
|
||||
} else {
|
||||
logger.Warn("Populating of group cache failed, using existing cache, if any")
|
||||
logger := logger().WithFields(logrus.Fields{
|
||||
"group": group,
|
||||
"total_count": count,
|
||||
})
|
||||
|
||||
if count == 0 {
|
||||
logger.Warn("Populating of group cache failed, cache will be empty until refresh succeeds")
|
||||
} else {
|
||||
logger.Warn("Populating of group cache failed, using existing cache, if any")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
count := b.groupedCache.ElementCount(group)
|
||||
|
||||
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, count)
|
||||
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, count)
|
||||
|
||||
logger().WithFields(logrus.Fields{
|
||||
"group": group,
|
||||
"total_count": count,
|
||||
}).Info("group import finished")
|
||||
logger().WithFields(logrus.Fields{
|
||||
"group": group,
|
||||
"total_count": count,
|
||||
}).Info("group import finished")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
return unlimitedGrp.Wait()
|
||||
}
|
||||
|
||||
func readFile(file string) (io.ReadCloser, error) {
|
||||
logger().WithField("file", file).Info("starting processing of file")
|
||||
file = strings.TrimPrefix(file, "file://")
|
||||
func (b *ListCache) createCacheForGroup(
|
||||
producersGrp, consumersGrp jobgroup.JobGroup, group string, sources []config.BytesSource,
|
||||
) error {
|
||||
groupFactory := b.groupedCache.Refresh(group)
|
||||
|
||||
return os.Open(file)
|
||||
producers := parcour.NewProducersWithBuffer[string](producersGrp, consumersGrp, groupProducersBufferCap)
|
||||
defer producers.Close()
|
||||
|
||||
for i, source := range sources {
|
||||
i, source := i, source
|
||||
|
||||
producers.GoProduce(func(ctx context.Context, hostsChan chan<- string) error {
|
||||
locInfo := fmt.Sprintf("item #%d of group %s", i, group)
|
||||
|
||||
opener, err := NewSourceOpener(locInfo, source, b.downloader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return b.parseFile(ctx, opener, hostsChan)
|
||||
})
|
||||
}
|
||||
|
||||
hasEntries := false
|
||||
|
||||
producers.GoConsume(func(ctx context.Context, ch <-chan string) error {
|
||||
for host := range ch {
|
||||
hasEntries = true
|
||||
|
||||
groupFactory.AddEntry(host)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
err := producers.Wait()
|
||||
if err != nil {
|
||||
if !hasEntries {
|
||||
// Always fail the group if no entries were parsed
|
||||
return err
|
||||
}
|
||||
|
||||
var transientErr *TransientError
|
||||
|
||||
if errors.As(err, &transientErr) {
|
||||
// Temporary error: fail the whole group to retry later
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
groupFactory.Finish()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloads file (or reads local file) and writes each line in the file to the result channel
|
||||
func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh chan<- string) error {
|
||||
func (b *ListCache) parseFile(ctx context.Context, opener SourceOpener, resultCh chan<- string) error {
|
||||
count := 0
|
||||
|
||||
logger := func() *logrus.Entry {
|
||||
return logger().WithFields(logrus.Fields{
|
||||
"source": name,
|
||||
"source": opener.String(),
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
|
||||
r, err := b.newLinkReader(link)
|
||||
logger().Debug("starting processing of source")
|
||||
|
||||
r, err := opener.Open()
|
||||
if err != nil {
|
||||
logger().Error("cannot open source: ", err)
|
||||
|
||||
|
@ -262,7 +217,7 @@ func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh c
|
|||
}
|
||||
defer r.Close()
|
||||
|
||||
p := parsers.AllowErrors(parsers.Hosts(r), b.maxErrorsPerFile)
|
||||
p := parsers.AllowErrors(parsers.Hosts(r), b.cfg.MaxErrorsPerSource)
|
||||
p.OnErr(func(err error) {
|
||||
logger().Warnf("parse error: %s, trying to continue", err)
|
||||
})
|
||||
|
@ -303,27 +258,3 @@ func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh c
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func linkName(linkIdx int, link string) string {
|
||||
if strings.ContainsAny(link, "\n") {
|
||||
return fmt.Sprintf("inline block (item #%d in group)", linkIdx)
|
||||
}
|
||||
|
||||
return link
|
||||
}
|
||||
|
||||
func (b *ListCache) newLinkReader(link string) (r io.ReadCloser, err error) {
|
||||
switch {
|
||||
// link contains a line break -> this is inline list definition in YAML (with literal style Block Scalar)
|
||||
case strings.ContainsAny(link, "\n"):
|
||||
r = io.NopCloser(strings.NewReader(link))
|
||||
// link is http(s) -> download it
|
||||
case strings.HasPrefix(link, "http"):
|
||||
r, err = b.downloader.DownloadFile(link)
|
||||
// probably path to a local file
|
||||
default:
|
||||
r, err = readFile(link)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,17 +2,24 @@ package lists
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
)
|
||||
|
||||
func BenchmarkRefresh(b *testing.B) {
|
||||
file1, _ := createTestListFile(b.TempDir(), 100000)
|
||||
file2, _ := createTestListFile(b.TempDir(), 150000)
|
||||
file3, _ := createTestListFile(b.TempDir(), 130000)
|
||||
lists := map[string][]string{
|
||||
"gr1": {file1, file2, file3},
|
||||
lists := map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(file1, file2, file3),
|
||||
}
|
||||
|
||||
cache, _ := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(), 5, false, 5)
|
||||
cfg := config.SourceLoadingConfig{
|
||||
Concurrency: 5,
|
||||
RefreshPeriod: config.Duration(-1),
|
||||
}
|
||||
downloader := NewDownloader(config.DownloaderConfig{}, nil)
|
||||
cache, _ := NewListCache(ListCacheTypeBlacklist, cfg, lists, downloader)
|
||||
|
||||
b.ReportAllocs()
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package lists
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -9,8 +10,8 @@ import (
|
|||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/evt"
|
||||
"github.com/0xERR0R/blocky/lists/parsers"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -27,13 +28,28 @@ var _ = Describe("ListCache", func() {
|
|||
tmpDir *TmpFolder
|
||||
emptyFile, file1, file2, file3 *TmpFile
|
||||
server1, server2, server3 *httptest.Server
|
||||
maxErrorsPerFile int
|
||||
|
||||
sut *ListCache
|
||||
sutConfig config.SourceLoadingConfig
|
||||
|
||||
listCacheType ListCacheType
|
||||
lists map[string][]config.BytesSource
|
||||
downloader FileDownloader
|
||||
mockDownloader *MockDownloader
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
maxErrorsPerFile = 5
|
||||
tmpDir = NewTmpFolder("ListCache")
|
||||
Expect(tmpDir.Error).Should(Succeed())
|
||||
DeferCleanup(tmpDir.Clean)
|
||||
var err error
|
||||
|
||||
listCacheType = ListCacheTypeBlacklist
|
||||
|
||||
sutConfig, err = config.WithDefaults[config.SourceLoadingConfig]()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
sutConfig.RefreshPeriod = -1
|
||||
|
||||
downloader = NewDownloader(config.DownloaderConfig{}, nil)
|
||||
mockDownloader = nil
|
||||
|
||||
server1 = TestServer("blocked1.com\nblocked1a.com\n192.168.178.55")
|
||||
DeferCleanup(server1.Close)
|
||||
|
@ -42,6 +58,13 @@ var _ = Describe("ListCache", func() {
|
|||
server3 = TestServer("blocked3.com\nblocked1a.com")
|
||||
DeferCleanup(server3.Close)
|
||||
|
||||
tmpDir = NewTmpFolder("ListCache")
|
||||
Expect(tmpDir.Error).Should(Succeed())
|
||||
DeferCleanup(tmpDir.Clean)
|
||||
|
||||
emptyFile = tmpDir.CreateStringFile("empty", "#empty file")
|
||||
Expect(emptyFile.Error).Should(Succeed())
|
||||
|
||||
emptyFile = tmpDir.CreateStringFile("empty", "#empty file")
|
||||
Expect(emptyFile.Error).Should(Succeed())
|
||||
file1 = tmpDir.CreateStringFile("file1", "blocked1.com", "blocked1a.com")
|
||||
|
@ -52,61 +75,56 @@ var _ = Describe("ListCache", func() {
|
|||
Expect(file3.Error).Should(Succeed())
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
var err error
|
||||
|
||||
Expect(lists).ShouldNot(BeNil(), "bad test: forgot to set `lists`")
|
||||
|
||||
if mockDownloader != nil {
|
||||
downloader = mockDownloader
|
||||
}
|
||||
|
||||
sut, err = NewListCache(listCacheType, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
Describe("List cache and matching", func() {
|
||||
When("Query with empty", func() {
|
||||
It("should not panic", func() {
|
||||
lists := map[string][]string{
|
||||
"gr0": {emptyFile.Path},
|
||||
}
|
||||
sut, err := NewListCache(
|
||||
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
|
||||
)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
group := sut.Match("", []string{"gr0"})
|
||||
Expect(group).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("List is empty", func() {
|
||||
It("should not match anything", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {emptyFile.Path},
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr0": config.NewBytesSources(emptyFile.Path),
|
||||
}
|
||||
sut, err := NewListCache(
|
||||
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
|
||||
)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
When("Query with empty", func() {
|
||||
It("should not panic", func() {
|
||||
group := sut.Match("", []string{"gr0"})
|
||||
Expect(group).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
It("should not match anything", func() {
|
||||
group := sut.Match("google.com", []string{"gr1"})
|
||||
Expect(group).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
When("List becomes empty on refresh", func() {
|
||||
It("should delete existing elements from group cache", func() {
|
||||
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
|
||||
BeforeEach(func() {
|
||||
mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) {
|
||||
res <- "blocked1.com"
|
||||
res <- "# nothing"
|
||||
})
|
||||
|
||||
lists := map[string][]string{
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {mockDownloader.ListSource()},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(
|
||||
ListCacheTypeBlacklist, lists,
|
||||
4*time.Hour,
|
||||
mockDownloader,
|
||||
defaultProcessingConcurrency,
|
||||
false,
|
||||
maxErrorsPerFile,
|
||||
)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should delete existing elements from group cache", func(ctx context.Context) {
|
||||
group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
|
||||
err = sut.refresh(false)
|
||||
err := sut.refresh(ctx)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
group = sut.Match("blocked1.com", []string{"gr1"})
|
||||
|
@ -114,21 +132,19 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
When("List has invalid lines", func() {
|
||||
It("should still other domains", func() {
|
||||
lists := map[string][]string{
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {
|
||||
inlineList(
|
||||
config.TextBytesSource(
|
||||
"inlinedomain1.com",
|
||||
"invaliddomain!",
|
||||
"inlinedomain2.com",
|
||||
),
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should still other domains", func() {
|
||||
group := sut.Match("inlinedomain1.com", []string{"gr1"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
|
||||
|
@ -137,28 +153,20 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
When("a temporary/transient err occurs on download", func() {
|
||||
It("should not delete existing elements from group cache", func() {
|
||||
BeforeEach(func() {
|
||||
// should produce a transient error on second and third attempt
|
||||
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
|
||||
res <- "blocked1.com"
|
||||
mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) {
|
||||
res <- "blocked1.com\nblocked2.com\n"
|
||||
err <- &TransientError{inner: errors.New("boom")}
|
||||
err <- &TransientError{inner: errors.New("boom")}
|
||||
})
|
||||
|
||||
lists := map[string][]string{
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {mockDownloader.ListSource()},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(
|
||||
ListCacheTypeBlacklist, lists,
|
||||
4*time.Hour,
|
||||
mockDownloader,
|
||||
defaultProcessingConcurrency,
|
||||
false,
|
||||
maxErrorsPerFile,
|
||||
)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should not delete existing elements from group cache", func(ctx context.Context) {
|
||||
By("Lists loaded without timeout", func() {
|
||||
Eventually(func(g Gomega) {
|
||||
group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
|
@ -166,7 +174,7 @@ var _ = Describe("ListCache", func() {
|
|||
}, "1s").Should(Succeed())
|
||||
})
|
||||
|
||||
Expect(sut.refresh(false)).Should(HaveOccurred())
|
||||
Expect(sut.refresh(ctx)).Should(HaveOccurred())
|
||||
|
||||
By("List couldn't be loaded due to timeout", func() {
|
||||
group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
|
@ -182,27 +190,25 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
When("non transient err occurs on download", func() {
|
||||
It("should keep existing elements from group cache", func() {
|
||||
BeforeEach(func() {
|
||||
// should produce a non transient error on second attempt
|
||||
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
|
||||
mockDownloader = newMockDownloader(func(res chan<- string, err chan<- error) {
|
||||
res <- "blocked1.com"
|
||||
err <- errors.New("boom")
|
||||
})
|
||||
|
||||
lists := map[string][]string{
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {mockDownloader.ListSource()},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, mockDownloader,
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should keep existing elements from group cache", func(ctx context.Context) {
|
||||
By("Lists loaded without err", func() {
|
||||
group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
})
|
||||
|
||||
Expect(sut.refresh(false)).Should(HaveOccurred())
|
||||
Expect(sut.refresh(ctx)).Should(HaveOccurred())
|
||||
|
||||
By("Lists from first load is kept", func() {
|
||||
group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
|
@ -211,16 +217,14 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
When("Configuration has 3 external working urls", func() {
|
||||
It("should download the list and match against", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {server1.URL, server2.URL},
|
||||
"gr2": {server3.URL},
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(server1.URL, server2.URL),
|
||||
"gr2": config.NewBytesSources(server3.URL),
|
||||
}
|
||||
})
|
||||
|
||||
sut, _ := NewListCache(
|
||||
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
|
||||
)
|
||||
|
||||
It("should download the list and match against", func() {
|
||||
group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
|
||||
|
@ -232,16 +236,14 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
When("Configuration has some faulty urls", func() {
|
||||
It("should download the list and match against", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {server1.URL, server2.URL, "doesnotexist"},
|
||||
"gr2": {server3.URL, "someotherfile"},
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(server1.URL, server2.URL, "doesnotexist"),
|
||||
"gr2": config.NewBytesSources(server3.URL, "someotherfile"),
|
||||
}
|
||||
})
|
||||
|
||||
sut, _ := NewListCache(
|
||||
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
|
||||
)
|
||||
|
||||
It("should download the list and match against", func() {
|
||||
group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
|
||||
|
@ -253,39 +255,33 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
When("List will be updated", func() {
|
||||
It("event should be fired and contain count of elements in downloaded lists", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {server1.URL},
|
||||
}
|
||||
resultCnt := 0
|
||||
|
||||
resultCnt := 0
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(server1.URL),
|
||||
}
|
||||
|
||||
_ = Bus().SubscribeOnce(BlockingCacheGroupChanged, func(listType ListCacheType, group string, cnt int) {
|
||||
resultCnt = cnt
|
||||
})
|
||||
})
|
||||
|
||||
sut, err := NewListCache(
|
||||
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
|
||||
)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("event should be fired and contain count of elements in downloaded lists", func() {
|
||||
group := sut.Match("blocked1.com", []string{})
|
||||
Expect(group).Should(BeEmpty())
|
||||
Expect(resultCnt).Should(Equal(3))
|
||||
})
|
||||
})
|
||||
When("multiple groups are passed", func() {
|
||||
It("should match", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {file1.Path, file2.Path},
|
||||
"gr2": {"file://" + file3.Path},
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(file1.Path, file2.Path),
|
||||
"gr2": config.NewBytesSources("file://" + file3.Path),
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(
|
||||
ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false, maxErrorsPerFile,
|
||||
)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should match", func() {
|
||||
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(3))
|
||||
Expect(sut.groupedCache.ElementCount("gr2")).Should(Equal(2))
|
||||
|
||||
|
@ -304,31 +300,28 @@ var _ = Describe("ListCache", func() {
|
|||
file1, lines1 := createTestListFile(GinkgoT().TempDir(), 10000)
|
||||
file2, lines2 := createTestListFile(GinkgoT().TempDir(), 15000)
|
||||
file3, lines3 := createTestListFile(GinkgoT().TempDir(), 13000)
|
||||
lists := map[string][]string{
|
||||
"gr1": {file1, file2, file3},
|
||||
lists := map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(file1, file2, file3),
|
||||
}
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(lines1 + lines2 + lines3))
|
||||
})
|
||||
})
|
||||
When("inline list content is defined", func() {
|
||||
It("should match", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {inlineList(
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {config.TextBytesSource(
|
||||
"inlinedomain1.com",
|
||||
"#some comment",
|
||||
"inlinedomain2.com",
|
||||
)},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should match", func() {
|
||||
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(2))
|
||||
group := sut.Match("inlinedomain1.com", []string{"gr1"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
|
@ -338,65 +331,59 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
When("Text file can't be parsed", func() {
|
||||
It("should still match already imported strings", func() {
|
||||
lists := map[string][]string{
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {
|
||||
inlineList(
|
||||
config.TextBytesSource(
|
||||
"inlinedomain1.com",
|
||||
"lineTooLong"+strings.Repeat("x", bufio.MaxScanTokenSize), // too long
|
||||
),
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should still match already imported strings", func() {
|
||||
group := sut.Match("inlinedomain1.com", []string{"gr1"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
})
|
||||
})
|
||||
When("Text file has too many errors", func() {
|
||||
BeforeEach(func() {
|
||||
maxErrorsPerFile = 0
|
||||
sutConfig.MaxErrorsPerSource = 0
|
||||
sutConfig.Strategy = config.StartStrategyTypeFailOnError
|
||||
})
|
||||
It("should fail parsing", func() {
|
||||
lists := map[string][]string{
|
||||
lists := map[string][]config.BytesSource{
|
||||
"gr1": {
|
||||
inlineList("invaliddomain!"), // too many errors since `maxErrorsPerFile` is 0
|
||||
config.TextBytesSource("invaliddomain!"), // too many errors since `maxErrorsPerSource` is 0
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err).Should(MatchError(parsers.ErrTooManyErrors))
|
||||
})
|
||||
})
|
||||
When("file has end of line comment", func() {
|
||||
It("should still parse the domain", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {inlineList("inlinedomain1.com#a comment")},
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {config.TextBytesSource("inlinedomain1.com#a comment")},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should still parse the domain", func() {
|
||||
group := sut.Match("inlinedomain1.com", []string{"gr1"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
})
|
||||
})
|
||||
When("inline regex content is defined", func() {
|
||||
It("should match", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {inlineList("/^apple\\.(de|com)$/")},
|
||||
BeforeEach(func() {
|
||||
lists = map[string][]config.BytesSource{
|
||||
"gr1": {config.TextBytesSource("/^apple\\.(de|com)$/")},
|
||||
}
|
||||
})
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
It("should match", func() {
|
||||
group := sut.Match("apple.com", []string{"gr1"})
|
||||
Expect(group).Should(ContainElement("gr1"))
|
||||
|
||||
|
@ -416,13 +403,12 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
|
||||
It("should print list configuration", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {server1.URL, server2.URL},
|
||||
"gr2": {inlineList("inline", "definition")},
|
||||
lists := map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(server1.URL, server2.URL),
|
||||
"gr2": {config.TextBytesSource("inline", "definition")},
|
||||
}
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(),
|
||||
defaultProcessingConcurrency, false, maxErrorsPerFile)
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
@ -435,13 +421,16 @@ var _ = Describe("ListCache", func() {
|
|||
|
||||
Describe("StartStrategy", func() {
|
||||
When("async load is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.Strategy = config.StartStrategyTypeFast
|
||||
})
|
||||
|
||||
It("should never return an error", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {"doesnotexist"},
|
||||
lists := map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources("doesnotexist"),
|
||||
}
|
||||
|
||||
_, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader(),
|
||||
defaultProcessingConcurrency, true, maxErrorsPerFile)
|
||||
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
})
|
||||
|
@ -465,8 +454,11 @@ func (m *MockDownloader) DownloadFile(_ string) (io.ReadCloser, error) {
|
|||
return io.NopCloser(strings.NewReader(str)), nil
|
||||
}
|
||||
|
||||
func (m *MockDownloader) ListSource() string {
|
||||
return "http://mock"
|
||||
func (m *MockDownloader) ListSource() config.BytesSource {
|
||||
return config.BytesSource{
|
||||
Type: config.BytesSourceTypeHttp,
|
||||
From: "http://mock-downloader",
|
||||
}
|
||||
}
|
||||
|
||||
func createTestListFile(dir string, totalLines int) (string, int) {
|
||||
|
@ -502,12 +494,3 @@ func RandStringBytes(n int) string {
|
|||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func inlineList(lines ...string) string {
|
||||
res := strings.Join(lines, "\n")
|
||||
|
||||
// ensure at least one line ending so it's parsed as an inline block
|
||||
res += "\n"
|
||||
|
||||
return res
|
||||
}
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
package lists
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
)
|
||||
|
||||
type SourceOpener interface {
|
||||
fmt.Stringer
|
||||
|
||||
Open() (io.ReadCloser, error)
|
||||
}
|
||||
|
||||
func NewSourceOpener(txtLocInfo string, source config.BytesSource, downloader FileDownloader) (SourceOpener, error) {
|
||||
switch source.Type {
|
||||
case config.BytesSourceTypeText:
|
||||
return &textOpener{source: source, locInfo: txtLocInfo}, nil
|
||||
|
||||
case config.BytesSourceTypeHttp:
|
||||
return &httpOpener{source: source, downloader: downloader}, nil
|
||||
|
||||
case config.BytesSourceTypeFile:
|
||||
return &fileOpener{source: source}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cannot open %s", source)
|
||||
}
|
||||
|
||||
type textOpener struct {
|
||||
source config.BytesSource
|
||||
locInfo string
|
||||
}
|
||||
|
||||
func (o *textOpener) Open() (io.ReadCloser, error) {
|
||||
return io.NopCloser(strings.NewReader(o.source.From)), nil
|
||||
}
|
||||
|
||||
func (o *textOpener) String() string {
|
||||
return fmt.Sprintf("%s: %s", o.locInfo, o.source)
|
||||
}
|
||||
|
||||
type httpOpener struct {
|
||||
source config.BytesSource
|
||||
downloader FileDownloader
|
||||
}
|
||||
|
||||
func (o *httpOpener) Open() (io.ReadCloser, error) {
|
||||
return o.downloader.DownloadFile(o.source.From)
|
||||
}
|
||||
|
||||
func (o *httpOpener) String() string {
|
||||
return o.source.String()
|
||||
}
|
||||
|
||||
type fileOpener struct {
|
||||
source config.BytesSource
|
||||
}
|
||||
|
||||
func (o *fileOpener) Open() (io.ReadCloser, error) {
|
||||
return os.Open(o.source.From)
|
||||
}
|
||||
|
||||
func (o *fileOpener) String() string {
|
||||
return o.source.String()
|
||||
}
|
|
@ -101,18 +101,14 @@ func NewBlockingResolver(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
refreshPeriod := cfg.RefreshPeriod.ToDuration()
|
||||
downloader := createDownloader(cfg, bootstrap)
|
||||
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists,
|
||||
refreshPeriod, downloader, cfg.ProcessingConcurrency,
|
||||
(cfg.StartStrategy == config.StartStrategyTypeFast), cfg.MaxErrorsPerFile)
|
||||
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.WhiteLists,
|
||||
refreshPeriod, downloader, cfg.ProcessingConcurrency,
|
||||
(cfg.StartStrategy == config.StartStrategyTypeFast), cfg.MaxErrorsPerFile)
|
||||
downloader := lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport())
|
||||
|
||||
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.Loading, cfg.BlackLists, downloader)
|
||||
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.Loading, cfg.WhiteLists, downloader)
|
||||
whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg)
|
||||
|
||||
err = multierror.Append(err, blErr, wlErr).ErrorOrNil()
|
||||
if err != nil && cfg.StartStrategy == config.StartStrategyTypeFailOnError {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -156,15 +152,6 @@ func NewBlockingResolver(
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func createDownloader(cfg config.BlockingConfig, bootstrap *Bootstrap) *lists.HTTPDownloader {
|
||||
return lists.NewDownloader(
|
||||
lists.WithTimeout(cfg.DownloadTimeout.ToDuration()),
|
||||
lists.WithAttempts(cfg.DownloadAttempts),
|
||||
lists.WithCooldown(cfg.DownloadCooldown.ToDuration()),
|
||||
lists.WithTransport(bootstrap.NewHTTPTransport()),
|
||||
)
|
||||
}
|
||||
|
||||
func setupRedisEnabledSubscriber(c *BlockingResolver) {
|
||||
go func() {
|
||||
for em := range c.redisClient.EnabledChannel {
|
||||
|
|
|
@ -97,9 +97,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{
|
||||
"gr1": {group1File.Path},
|
||||
"gr2": {group2File.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(group1File.Path),
|
||||
"gr2": config.NewBytesSources(group2File.Path),
|
||||
},
|
||||
}
|
||||
})
|
||||
|
@ -125,9 +125,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{
|
||||
"gr1": {group1File.Path},
|
||||
"gr2": {group2File.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(group1File.Path),
|
||||
"gr2": config.NewBytesSources(group2File.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
|
@ -164,13 +164,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{
|
||||
"gr1": {"\n/regex/"},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"gr1": {config.TextBytesSource("/regex/")},
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
},
|
||||
StartStrategy: config.StartStrategyTypeFast,
|
||||
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFast},
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -193,10 +193,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockTTL: config.Duration(6 * time.Hour),
|
||||
BlackLists: map[string][]string{
|
||||
"gr1": {group1File.Path},
|
||||
"gr2": {group2File.Path},
|
||||
"defaultGroup": {defaultGroupFile.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(group1File.Path),
|
||||
"gr2": config.NewBytesSources(group2File.Path),
|
||||
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"Client1": {"gr1"},
|
||||
|
@ -399,8 +399,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{
|
||||
"defaultGroup": {defaultGroupFile.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"defaultGroup"},
|
||||
|
@ -425,8 +425,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlackLists: map[string][]string{
|
||||
"defaultGroup": {defaultGroupFile.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"defaultGroup"},
|
||||
|
@ -470,8 +470,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockTTL: config.Duration(6 * time.Hour),
|
||||
BlackLists: map[string][]string{
|
||||
"defaultGroup": {defaultGroupFile.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"defaultGroup"},
|
||||
|
@ -508,8 +508,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
When("BlockType is custom IP only for ipv4", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlackLists: map[string][]string{
|
||||
"defaultGroup": {defaultGroupFile.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"defaultGroup"},
|
||||
|
@ -601,8 +601,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{"gr1": {group1File.Path}},
|
||||
WhiteLists: map[string][]string{"gr1": {group1File.Path}},
|
||||
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
|
||||
WhiteLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
},
|
||||
|
@ -627,9 +627,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "zeroIP",
|
||||
BlockTTL: config.Duration(60 * time.Second),
|
||||
WhiteLists: map[string][]string{
|
||||
"gr1": {group1File.Path},
|
||||
"gr2": {group2File.Path},
|
||||
WhiteLists: map[string][]config.BytesSource{
|
||||
"gr1": config.NewBytesSources(group1File.Path),
|
||||
"gr2": config.NewBytesSources(group2File.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
|
@ -728,8 +728,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{"gr1": {group1File.Path}},
|
||||
WhiteLists: map[string][]string{"gr1": {defaultGroupFile.Path}},
|
||||
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
|
||||
WhiteLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(defaultGroupFile.Path)},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
},
|
||||
|
@ -755,7 +755,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
BlackLists: map[string][]string{"gr1": {group1File.Path}},
|
||||
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources(group1File.Path)},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"gr1"},
|
||||
},
|
||||
|
@ -798,9 +798,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
Describe("Control status via API", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlackLists: map[string][]string{
|
||||
"defaultGroup": {defaultGroupFile.Path},
|
||||
"group1": {group1File.Path},
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"defaultGroup": config.NewBytesSources(defaultGroupFile.Path),
|
||||
"group1": config.NewBytesSources(group1File.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"defaultGroup", "group1"},
|
||||
|
@ -1118,13 +1118,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
MatchError("unknown blockType 'wrong', please use one of: ZeroIP, NxDomain or specify destination IP address(es)"))
|
||||
})
|
||||
})
|
||||
When("startStrategy is failOnError", func() {
|
||||
When("strategy is failOnError", func() {
|
||||
It("should fail if lists can't be downloaded", func() {
|
||||
_, err := NewBlockingResolver(config.BlockingConfig{
|
||||
BlackLists: map[string][]string{"gr1": {"wrongPath"}},
|
||||
WhiteLists: map[string][]string{"whitelist": {"wrongPath"}},
|
||||
StartStrategy: config.StartStrategyTypeFailOnError,
|
||||
BlockType: "zeroIp",
|
||||
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")},
|
||||
WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")},
|
||||
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError},
|
||||
BlockType: "zeroIp",
|
||||
}, nil, systemResolverBootstrap)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
|
|
|
@ -4,13 +4,14 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/lists"
|
||||
"github.com/0xERR0R/blocky/lists/parsers"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/ThinkChaos/parcour"
|
||||
"github.com/ThinkChaos/parcour/jobgroup"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -18,33 +19,37 @@ import (
|
|||
const (
|
||||
// reduce initial capacity so we don't waste memory if there are less entries than before
|
||||
memReleaseFactor = 2
|
||||
|
||||
producersBuffCap = 1000
|
||||
)
|
||||
|
||||
type HostsFileEntry = parsers.HostsFileEntry
|
||||
|
||||
type HostsFileResolver struct {
|
||||
configurable[*config.HostsFileConfig]
|
||||
NextResolver
|
||||
typed
|
||||
|
||||
hosts splitHostsFileData
|
||||
hosts splitHostsFileData
|
||||
downloader lists.FileDownloader
|
||||
}
|
||||
|
||||
type HostsFileEntry = parsers.HostsFileEntry
|
||||
|
||||
func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
|
||||
func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*HostsFileResolver, error) {
|
||||
r := HostsFileResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("hosts_file"),
|
||||
|
||||
downloader: lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()),
|
||||
}
|
||||
|
||||
if err := r.parseHostsFile(context.Background()); err != nil {
|
||||
r.log().Errorf("disabling hosts file resolving due to error: %s", err)
|
||||
|
||||
r.cfg.Filepath = "" // don't try parsing the file again
|
||||
} else {
|
||||
go r.periodicUpdate()
|
||||
err := cfg.Loading.StartPeriodicRefresh(r.loadSources, func(err error) {
|
||||
r.log().WithError(err).Errorf("could not load hosts files")
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &r
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
|
@ -102,7 +107,7 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
}
|
||||
|
||||
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
if r.cfg.Filepath == "" {
|
||||
if !r.IsEnabled() {
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
|
@ -144,27 +149,78 @@ func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain
|
|||
return response
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
|
||||
const maxErrorsPerFile = 5
|
||||
|
||||
if r.cfg.Filepath == "" {
|
||||
func (r *HostsFileResolver) loadSources(ctx context.Context) error {
|
||||
if !r.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(r.cfg.Filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
r.log().Debug("loading hosts files")
|
||||
|
||||
//nolint:ineffassign,staticcheck // keep `ctx :=` so if we use ctx in the future, we use the correct one
|
||||
consumersGrp, ctx := jobgroup.WithContext(ctx)
|
||||
defer consumersGrp.Close()
|
||||
|
||||
producersGrp := jobgroup.WithMaxConcurrency(consumersGrp, r.cfg.Loading.Concurrency)
|
||||
defer producersGrp.Close()
|
||||
|
||||
producers := parcour.NewProducersWithBuffer[*HostsFileEntry](producersGrp, consumersGrp, producersBuffCap)
|
||||
defer producers.Close()
|
||||
|
||||
for i, source := range r.cfg.Sources {
|
||||
i, source := i, source
|
||||
|
||||
producers.GoProduce(func(ctx context.Context, hostsChan chan<- *HostsFileEntry) error {
|
||||
locInfo := fmt.Sprintf("item #%d", i)
|
||||
|
||||
opener, err := lists.NewSourceOpener(locInfo, source, r.downloader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.parseFile(ctx, opener, hostsChan)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing %s: %w", opener, err) // err is parsers.ErrTooManyErrors
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
newHosts := newSplitHostsDataWithSameCapacity(r.hosts)
|
||||
|
||||
p := parsers.AllowErrors(parsers.HostsFile(f), maxErrorsPerFile)
|
||||
p.OnErr(func(err error) {
|
||||
r.log().Warnf("error parsing %s: %s, trying to continue", r.cfg.Filepath, err)
|
||||
producers.GoConsume(func(ctx context.Context, ch <-chan *HostsFileEntry) error {
|
||||
for entry := range ch {
|
||||
newHosts.add(entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
err = parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {
|
||||
err := producers.Wait()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.hosts = newHosts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) parseFile(
|
||||
ctx context.Context, opener lists.SourceOpener, hostsChan chan<- *HostsFileEntry,
|
||||
) error {
|
||||
reader, err := opener.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
p := parsers.AllowErrors(parsers.HostsFile(reader), r.cfg.Loading.MaxErrorsPerSource)
|
||||
p.OnErr(func(err error) {
|
||||
r.log().Warnf("error parsing %s: %s, trying to continue", opener, err)
|
||||
})
|
||||
|
||||
return parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {
|
||||
if len(entry.Interface) != 0 {
|
||||
// Ignore entries with a specific interface: we don't restrict what clients/interfaces we serve entries to,
|
||||
// so this avoids returning entries that can't be accessed by the client.
|
||||
|
@ -176,32 +232,10 @@ func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
newHosts.add(entry)
|
||||
hostsChan <- entry
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing %s: %w", r.cfg.Filepath, err) // err is parsers.ErrTooManyErrors
|
||||
}
|
||||
|
||||
r.hosts = newHosts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) periodicUpdate() {
|
||||
if r.cfg.RefreshPeriod.ToDuration() > 0 {
|
||||
ticker := time.NewTicker(r.cfg.RefreshPeriod.ToDuration())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
|
||||
r.log().WithField("file", r.cfg.Filepath).Debug("refreshing hosts file")
|
||||
|
||||
util.LogOnError("can't refresh hosts file: ", r.parseHostsFile(context.Background()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stores hosts file data for IP versions separately
|
||||
|
|
|
@ -40,15 +40,22 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
Expect(tmpFile.Error).Should(Succeed())
|
||||
|
||||
sutConfig = config.HostsFileConfig{
|
||||
Filepath: tmpFile.Path,
|
||||
Sources: config.NewBytesSources(tmpFile.Path),
|
||||
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
|
||||
RefreshPeriod: config.Duration(30 * time.Minute),
|
||||
FilterLoopback: true,
|
||||
Loading: config.SourceLoadingConfig{
|
||||
RefreshPeriod: -1,
|
||||
MaxErrorsPerSource: 5,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sut = NewHostsFileResolver(sutConfig)
|
||||
var err error
|
||||
|
||||
sut, err = NewHostsFileResolver(sutConfig, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
sut.Next(m)
|
||||
|
@ -74,12 +81,12 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
When("Hosts file cannot be located", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.HostsFileConfig{
|
||||
Filepath: "/this/file/does/not/exist",
|
||||
Sources: config.NewBytesSources("/this/file/does/not/exist"),
|
||||
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
|
||||
}
|
||||
})
|
||||
It("should not parse any hosts", func() {
|
||||
Expect(sut.cfg.Filepath).Should(BeEmpty())
|
||||
Expect(sut.cfg.Sources).ShouldNot(BeEmpty())
|
||||
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
|
||||
|
@ -99,13 +106,15 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
When("Hosts file is not set", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewHostsFileResolver(config.HostsFileConfig{})
|
||||
sutConfig.Deprecated.Filepath = new(config.BytesSource)
|
||||
sutConfig.Sources = nil
|
||||
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
sut.Next(m)
|
||||
})
|
||||
It("should not return an error", func() {
|
||||
err := sut.parseHostsFile(context.Background())
|
||||
err := sut.loadSources(context.Background())
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("should go to next resolver on query", func() {
|
||||
|
@ -156,12 +165,12 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
)
|
||||
Expect(tmpFile.Error).Should(Succeed())
|
||||
|
||||
sutConfig.Filepath = tmpFile.Path
|
||||
sutConfig.Sources = config.NewBytesSources(tmpFile.Path)
|
||||
})
|
||||
|
||||
It("should not be used", func() {
|
||||
Expect(sut).ShouldNot(BeNil())
|
||||
Expect(sut.cfg.Filepath).Should(BeEmpty())
|
||||
Expect(sut.cfg.Sources).ShouldNot(BeEmpty())
|
||||
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
|
||||
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
|
||||
|
|
|
@ -400,15 +400,17 @@ func createQueryResolver(
|
|||
parallel, pErr := resolver.NewParallelBestResolver(cfg.Upstream, bootstrap, cfg.StartVerifyUpstream)
|
||||
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
|
||||
hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap)
|
||||
|
||||
mErr := multierror.Append(
|
||||
err = multierror.Append(
|
||||
multierror.Prefix(blErr, "blocking resolver: "),
|
||||
multierror.Prefix(pErr, "parallel resolver: "),
|
||||
multierror.Prefix(cnErr, "client names resolver: "),
|
||||
multierror.Prefix(cuErr, "conditional upstream resolver: "),
|
||||
)
|
||||
if mErr.ErrorOrNil() != nil {
|
||||
return nil, mErr
|
||||
multierror.Prefix(hfErr, "hosts file resolver: "),
|
||||
).ErrorOrNil()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r = resolver.Chain(
|
||||
|
@ -419,7 +421,7 @@ func createQueryResolver(
|
|||
resolver.NewQueryLoggingResolver(cfg.QueryLog),
|
||||
resolver.NewMetricsResolver(cfg.Prometheus),
|
||||
resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
|
||||
resolver.NewHostsFileResolver(cfg.HostsFile),
|
||||
hostsFile,
|
||||
blocking,
|
||||
resolver.NewCachingResolver(cfg.Caching, redisClient),
|
||||
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),
|
||||
|
|
|
@ -122,17 +122,17 @@ var _ = BeforeSuite(func() {
|
|||
},
|
||||
},
|
||||
Blocking: config.BlockingConfig{
|
||||
BlackLists: map[string][]string{
|
||||
"ads": {
|
||||
BlackLists: map[string][]config.BytesSource{
|
||||
"ads": config.NewBytesSources(
|
||||
doubleclickFile.Path,
|
||||
bildFile.Path,
|
||||
heiseFile.Path,
|
||||
},
|
||||
"youtube": {youtubeFile.Path},
|
||||
),
|
||||
"youtube": config.NewBytesSources(youtubeFile.Path),
|
||||
},
|
||||
WhiteLists: map[string][]string{
|
||||
"ads": {heiseFile.Path},
|
||||
"whitelist": {heiseFile.Path},
|
||||
WhiteLists: map[string][]config.BytesSource{
|
||||
"ads": config.NewBytesSources(heiseFile.Path),
|
||||
"whitelist": config.NewBytesSources(heiseFile.Path),
|
||||
},
|
||||
ClientGroupsBlock: map[string][]string{
|
||||
"default": {"ads"},
|
||||
|
|
Loading…
Reference in New Issue