mirror of https://github.com/0xERR0R/blocky.git
chore: refactoring list cache, extracting download functionality (#508)
This commit is contained in:
parent
8e472aa271
commit
53814a2208
|
@ -473,7 +473,7 @@ type BlockingConfig struct {
|
|||
BlockType string `yaml:"blockType" default:"ZEROIP"`
|
||||
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
|
||||
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
|
||||
DownloadAttempts int `yaml:"downloadAttempts" default:"3"`
|
||||
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
|
||||
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
|
||||
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
|
||||
FailStartOnListError bool `yaml:"failStartOnListError" default:"false"`
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
package lists
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/evt"
|
||||
"github.com/avast/retry-go/v4"
|
||||
)
|
||||
|
||||
// Fallback values used by NewDownloader when the corresponding
// DownloaderOption is not supplied.
const (
	defaultDownloadTimeout  = time.Second
	defaultDownloadAttempts = uint(1)
	defaultDownloadCooldown = 500 * time.Millisecond
)
|
||||
|
||||
// TransientError represents a temporary error like timeout, network errors...
|
||||
type TransientError struct {
|
||||
inner error
|
||||
}
|
||||
|
||||
func (e *TransientError) Error() string {
|
||||
return fmt.Sprintf("temporary error occurred: %v", e.inner)
|
||||
}
|
||||
|
||||
func (e *TransientError) Unwrap() error {
|
||||
return e.inner
|
||||
}
|
||||
|
||||
// FileDownloader is able to download some text file
type FileDownloader interface {
	// DownloadFile fetches the content behind the given link and returns a
	// reader for it; the caller is responsible for closing the reader.
	DownloadFile(link string) (io.ReadCloser, error)
}
|
||||
|
||||
// HTTPDownloader downloads files via HTTP protocol
type HTTPDownloader struct {
	downloadTimeout  time.Duration   // per-request timeout of the HTTP client
	downloadAttempts uint            // maximum number of download attempts
	downloadCooldown time.Duration   // fixed pause between two attempts
	httpTransport    *http.Transport // transport used by the HTTP client
}
|
||||
|
||||
// DownloaderOption configures an HTTPDownloader (functional options pattern).
type DownloaderOption func(c *HTTPDownloader)
|
||||
|
||||
func NewDownloader(options ...DownloaderOption) *HTTPDownloader {
|
||||
d := &HTTPDownloader{
|
||||
downloadTimeout: defaultDownloadTimeout,
|
||||
downloadAttempts: defaultDownloadAttempts,
|
||||
downloadCooldown: defaultDownloadCooldown,
|
||||
httpTransport: &http.Transport{},
|
||||
}
|
||||
|
||||
for _, opt := range options {
|
||||
opt(d)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// WithTimeout sets the download timeout (per HTTP request).
func WithTimeout(timeout time.Duration) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.downloadTimeout = timeout
	}
}
|
||||
|
||||
// WithCooldown sets the pause between 2 download attempts
// (comment previously said "WithTimeout" — copy-paste error).
func WithCooldown(cooldown time.Duration) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.downloadCooldown = cooldown
	}
}
|
||||
|
||||
// WithAttempts sets the attempt number for retry
// (comment previously said "WithTimeout" — copy-paste error).
func WithAttempts(downloadAttempts uint) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.downloadAttempts = downloadAttempts
	}
}
|
||||
|
||||
// WithTransport sets the HTTP transport
// (comment previously said "WithTimeout" — copy-paste error).
func WithTransport(httpTransport *http.Transport) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.httpTransport = httpTransport
	}
}
|
||||
|
||||
// DownloadFile downloads the file behind the given link via HTTP(S),
// retrying up to d.downloadAttempts times with a fixed d.downloadCooldown
// pause between attempts. On success it returns the response body; the
// caller must close it. Timeouts and temporary network errors are wrapped
// in TransientError so callers can distinguish them via errors.As.
func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
	client := http.Client{
		Timeout:   d.downloadTimeout,
		Transport: d.httpTransport,
	}

	logger().WithField("link", link).Info("starting download")

	var body io.ReadCloser

	err := retry.Do(
		func() error {
			var resp *http.Response
			var httpErr error
			// body is intentionally left open on success: ownership passes to the caller
			//nolint:bodyclose
			if resp, httpErr = client.Get(link); httpErr == nil {
				if resp.StatusCode == http.StatusOK {
					body = resp.Body
					return nil
				}

				// non-200 response: close the body and turn the status into an error
				_ = resp.Body.Close()

				return fmt.Errorf("got status code %d", resp.StatusCode)
			}
			var netErr net.Error
			// classify timeouts / temporary network errors as transient so callers
			// (e.g. the list cache) can keep stale data instead of discarding it
			// NOTE(review): net.Error.Temporary() is deprecated in newer Go
			// releases — consider relying on Timeout() only; confirm before changing
			if errors.As(httpErr, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
				return &TransientError{inner: netErr}
			}
			return httpErr
		},
		retry.Attempts(d.downloadAttempts),
		retry.DelayType(retry.FixedDelay),
		retry.Delay(d.downloadCooldown),
		retry.LastErrorOnly(true),
		retry.OnRetry(func(n uint, err error) {
			// log each failed attempt with its ordinal and publish a failure event
			var transientErr *TransientError

			var dnsErr *net.DNSError

			logger := logger().WithField("link", link).WithField("attempt",
				fmt.Sprintf("%d/%d", n+1, d.downloadAttempts))

			switch {
			case errors.As(err, &transientErr):
				logger.Warnf("Temporary network err / Timeout occurred: %s", transientErr)
			case errors.As(err, &dnsErr):
				logger.Warnf("Name resolution err: %s", dnsErr.Err)
			default:
				logger.Warnf("Can't download file: %s", err)
			}

			onDownloadError(link)
		}))

	return body, err
}
|
||||
|
||||
// onDownloadError publishes a download-failure event for the given link so
// interested subscribers (e.g. metrics/stats) can react.
func onDownloadError(link string) {
	evt.Bus().Publish(evt.CachingFailedDownloadChanged, link)
}
|
|
@ -0,0 +1,217 @@
|
|||
package lists
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
. "github.com/0xERR0R/blocky/evt"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
)
|
||||
|
||||
var _ = Describe("Downloader", func() {
|
||||
var (
|
||||
sut *HTTPDownloader
|
||||
failedDownloadCountEvtChannel chan string
|
||||
loggerHook *test.Hook
|
||||
)
|
||||
BeforeEach(func() {
|
||||
failedDownloadCountEvtChannel = make(chan string, 5)
|
||||
// collect received events in the channel
|
||||
fn := func(url string) {
|
||||
failedDownloadCountEvtChannel <- url
|
||||
}
|
||||
Expect(Bus().Subscribe(CachingFailedDownloadChanged, fn)).Should(Succeed())
|
||||
DeferCleanup(func() {
|
||||
Expect(Bus().Unsubscribe(CachingFailedDownloadChanged, fn))
|
||||
})
|
||||
|
||||
loggerHook = test.NewGlobal()
|
||||
log.Log().AddHook(loggerHook)
|
||||
DeferCleanup(loggerHook.Reset)
|
||||
})
|
||||
|
||||
Describe("Construct downloader", func() {
|
||||
When("No options are provided", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader()
|
||||
})
|
||||
It("Should provide default valus", func() {
|
||||
Expect(sut.downloadAttempts).Should(BeNumerically("==", defaultDownloadAttempts))
|
||||
Expect(sut.downloadTimeout).Should(BeNumerically("==", defaultDownloadTimeout))
|
||||
Expect(sut.downloadCooldown).Should(BeNumerically("==", defaultDownloadCooldown))
|
||||
})
|
||||
})
|
||||
When("Options are provided", func() {
|
||||
transport := &http.Transport{}
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithAttempts(5),
|
||||
WithCooldown(2*time.Second),
|
||||
WithTimeout(5*time.Second),
|
||||
WithTransport(transport),
|
||||
)
|
||||
})
|
||||
It("Should use provided parameters", func() {
|
||||
Expect(sut.downloadAttempts).Should(BeNumerically("==", 5))
|
||||
Expect(sut.downloadTimeout).Should(BeNumerically("==", 5*time.Second))
|
||||
Expect(sut.downloadCooldown).Should(BeNumerically("==", 2*time.Second))
|
||||
Expect(sut.httpTransport).Should(BeIdenticalTo(transport))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Download of a file", func() {
|
||||
var server *httptest.Server
|
||||
When("Download was successful", func() {
|
||||
BeforeEach(func() {
|
||||
server = TestServer("line.one\nline.two")
|
||||
DeferCleanup(server.Close)
|
||||
|
||||
sut = NewDownloader()
|
||||
})
|
||||
It("Should return all lines from the file", func() {
|
||||
reader, err := sut.DownloadFile(server.URL)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(reader).Should(Not(BeNil()))
|
||||
DeferCleanup(reader.Close)
|
||||
buf := new(strings.Builder)
|
||||
_, err = io.Copy(buf, reader)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(buf.String()).Should(Equal("line.one\nline.two"))
|
||||
})
|
||||
|
||||
})
|
||||
When("Server returns NOT_FOUND (404)", func() {
|
||||
BeforeEach(func() {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
DeferCleanup(server.Close)
|
||||
|
||||
sut = NewDownloader(WithAttempts(3))
|
||||
})
|
||||
It("Should return error", func() {
|
||||
reader, err := sut.DownloadFile(server.URL)
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(reader).Should(BeNil())
|
||||
Expect(err.Error()).Should(Equal("got status code 404"))
|
||||
Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
|
||||
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
|
||||
})
|
||||
|
||||
})
|
||||
When("Wrong URL is defined", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader()
|
||||
})
|
||||
It("Should return error", func() {
|
||||
_, err := sut.DownloadFile("somewrongurl")
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Can't download file: "))
|
||||
// failed download event was emitted only once
|
||||
Expect(failedDownloadCountEvtChannel).Should(HaveLen(1))
|
||||
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("somewrongurl")))
|
||||
})
|
||||
})
|
||||
|
||||
When("If timeout occurs on first request", func() {
|
||||
var attempt uint64 = 1
|
||||
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithTimeout(20*time.Millisecond),
|
||||
WithAttempts(3),
|
||||
WithCooldown(time.Millisecond))
|
||||
|
||||
// should produce a timeout on first attempt
|
||||
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
a := atomic.LoadUint64(&attempt)
|
||||
atomic.AddUint64(&attempt, 1)
|
||||
if a == 1 {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
} else {
|
||||
_, err := rw.Write([]byte("blocked1.com"))
|
||||
Expect(err).Should(Succeed())
|
||||
}
|
||||
}))
|
||||
DeferCleanup(server.Close)
|
||||
|
||||
})
|
||||
It("Should perform a retry and return file content", func() {
|
||||
reader, err := sut.DownloadFile(server.URL)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(reader).Should(Not(BeNil()))
|
||||
DeferCleanup(reader.Close)
|
||||
|
||||
buf := new(strings.Builder)
|
||||
_, err = io.Copy(buf, reader)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(buf.String()).Should(Equal("blocked1.com"))
|
||||
|
||||
// failed download event was emitted only once
|
||||
Expect(failedDownloadCountEvtChannel).Should(HaveLen(1))
|
||||
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
|
||||
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Temporary network err / Timeout occurred: "))
|
||||
})
|
||||
})
|
||||
When("If timeout occurs on all request", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithTimeout(100*time.Millisecond),
|
||||
WithAttempts(3),
|
||||
WithCooldown(time.Millisecond))
|
||||
|
||||
// should always produce a timeout
|
||||
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}))
|
||||
DeferCleanup(server.Close)
|
||||
})
|
||||
It("Should perform a retry until max retry attempt count is reached and return TransientError", func() {
|
||||
reader, err := sut.DownloadFile(server.URL)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
var transientErr *TransientError
|
||||
Expect(errors.As(err, &transientErr)).To(BeTrue())
|
||||
Expect(transientErr.Unwrap().Error()).Should(ContainSubstring("Timeout"))
|
||||
Expect(reader).Should(BeNil())
|
||||
|
||||
// failed download event was emitted 3 times
|
||||
Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
|
||||
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
|
||||
})
|
||||
})
|
||||
When("DNS resolution of passed URL fails", func() {
|
||||
BeforeEach(func() {
|
||||
sut = NewDownloader(
|
||||
WithTimeout(100*time.Millisecond),
|
||||
WithAttempts(3),
|
||||
WithCooldown(time.Millisecond))
|
||||
})
|
||||
It("Should perform a retry until max retry attempt count is reached and return DNSError", func() {
|
||||
reader, err := sut.DownloadFile("http://some.domain.which.does.not.exist")
|
||||
Expect(err).Should(HaveOccurred())
|
||||
var dnsError *net.DNSError
|
||||
Expect(errors.As(err, &dnsError)).To(BeTrue())
|
||||
Expect(reader).Should(BeNil())
|
||||
|
||||
// failed download event was emitted 3 times
|
||||
Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
|
||||
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("http://some.domain.which.does.not.exist")))
|
||||
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: "))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -7,15 +7,13 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/cache/stringcache"
|
||||
|
||||
"github.com/avast/retry-go/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/hako/durafmt"
|
||||
|
||||
|
@ -23,7 +21,6 @@ import (
|
|||
|
||||
"github.com/0xERR0R/blocky/evt"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListCacheType represents the type of cached list ENUM(
|
||||
|
@ -46,13 +43,10 @@ type ListCache struct {
|
|||
groupCaches map[string]stringcache.StringCache
|
||||
lock sync.RWMutex
|
||||
|
||||
groupToLinks map[string][]string
|
||||
refreshPeriod time.Duration
|
||||
downloadTimeout time.Duration
|
||||
downloadAttempts int
|
||||
downloadCooldown time.Duration
|
||||
httpTransport *http.Transport
|
||||
listType ListCacheType
|
||||
groupToLinks map[string][]string
|
||||
refreshPeriod time.Duration
|
||||
downloader FileDownloader
|
||||
listType ListCacheType
|
||||
}
|
||||
|
||||
// Configuration returns current configuration and stats
|
||||
|
@ -80,6 +74,9 @@ func (b *ListCache) Configuration() (result []string) {
|
|||
|
||||
var total int
|
||||
|
||||
b.lock.RLock()
|
||||
defer b.lock.RUnlock()
|
||||
|
||||
for group, cache := range b.groupCaches {
|
||||
result = append(result, fmt.Sprintf(" %s: %d entries", group, cache.ElementCount()))
|
||||
total += cache.ElementCount()
|
||||
|
@ -92,19 +89,15 @@ func (b *ListCache) Configuration() (result []string) {
|
|||
|
||||
// NewListCache creates new list instance
|
||||
func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeriod time.Duration,
|
||||
downloadTimeout time.Duration, downloadAttempts int,
|
||||
downloadCooldown time.Duration, httpTransport *http.Transport) (*ListCache, error) {
|
||||
downloader FileDownloader) (*ListCache, error) {
|
||||
groupCaches := make(map[string]stringcache.StringCache)
|
||||
|
||||
b := &ListCache{
|
||||
groupToLinks: groupToLinks,
|
||||
groupCaches: groupCaches,
|
||||
refreshPeriod: refreshPeriod,
|
||||
downloadTimeout: downloadTimeout,
|
||||
downloadAttempts: downloadAttempts,
|
||||
downloadCooldown: downloadCooldown,
|
||||
httpTransport: httpTransport,
|
||||
listType: t,
|
||||
groupToLinks: groupToLinks,
|
||||
groupCaches: groupCaches,
|
||||
refreshPeriod: refreshPeriod,
|
||||
downloader: downloader,
|
||||
listType: t,
|
||||
}
|
||||
initError := b.refresh(true)
|
||||
|
||||
|
@ -217,12 +210,12 @@ func (b *ListCache) refresh(init bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
if b.groupCaches[group] != nil {
|
||||
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, b.groupCaches[group].ElementCount())
|
||||
if cacheForGroup != nil {
|
||||
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, cacheForGroup.ElementCount())
|
||||
|
||||
logger().WithFields(logrus.Fields{
|
||||
"group": group,
|
||||
"total_count": b.groupCaches[group].ElementCount(),
|
||||
"total_count": cacheForGroup.ElementCount(),
|
||||
}).Info("group import finished")
|
||||
}
|
||||
}
|
||||
|
@ -230,61 +223,6 @@ func (b *ListCache) refresh(init bool) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (b *ListCache) downloadFile(link string) (io.ReadCloser, error) {
|
||||
client := http.Client{
|
||||
Timeout: b.downloadTimeout,
|
||||
Transport: b.httpTransport,
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
|
||||
logger().WithField("link", link).Info("starting download")
|
||||
|
||||
var body io.ReadCloser
|
||||
|
||||
err := retry.Do(
|
||||
func() error {
|
||||
var err error
|
||||
//nolint:bodyclose
|
||||
if resp, err = client.Get(link); err == nil {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
body = resp.Body
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
|
||||
return fmt.Errorf("got status code %d", resp.StatusCode)
|
||||
}
|
||||
return err
|
||||
},
|
||||
retry.Attempts(uint(b.downloadAttempts)),
|
||||
retry.DelayType(retry.FixedDelay),
|
||||
retry.Delay(b.downloadCooldown),
|
||||
retry.LastErrorOnly(true),
|
||||
retry.OnRetry(func(n uint, err error) {
|
||||
var netErr net.Error
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
logger := logger().WithField("link", link).WithField("attempt",
|
||||
fmt.Sprintf("%d/%d", n+1, b.downloadAttempts))
|
||||
|
||||
switch {
|
||||
case errors.As(err, &netErr) && (netErr.Timeout() || netErr.Temporary()):
|
||||
logger.Warnf("Temporary network err / Timeout occurred: %s", netErr)
|
||||
case errors.As(err, &dnsErr):
|
||||
logger.Warnf("Name resolution err: %s", dnsErr.Err)
|
||||
default:
|
||||
logger.Warnf("Can't download file: %s", err)
|
||||
}
|
||||
|
||||
evt.Bus().Publish(evt.CachingFailedDownloadChanged, link)
|
||||
}))
|
||||
|
||||
return body, err
|
||||
}
|
||||
|
||||
func readFile(file string) (io.ReadCloser, error) {
|
||||
logger().WithField("file", file).Info("starting processing of file")
|
||||
file = strings.TrimPrefix(file, "file://")
|
||||
|
@ -307,12 +245,12 @@ func (b *ListCache) processFile(link string, ch chan<- groupCache, wg *sync.Wait
|
|||
r, err = b.getLinkReader(link)
|
||||
|
||||
if err != nil {
|
||||
logger().Warn("err during file processing: ", err)
|
||||
logger().Warn("error during file processing: ", err)
|
||||
result.err = multierror.Append(result.err, err)
|
||||
|
||||
var netErr net.Error
|
||||
var transientErr *TransientError
|
||||
|
||||
if errors.As(err, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
|
||||
if errors.As(err, &transientErr) {
|
||||
// put nil to indicate the temporary err
|
||||
result.cache = nil
|
||||
}
|
||||
|
@ -354,7 +292,7 @@ func (b *ListCache) getLinkReader(link string) (r io.ReadCloser, err error) {
|
|||
r = io.NopCloser(strings.NewReader(link))
|
||||
// link is http(s) -> download it
|
||||
case strings.HasPrefix(link, "http"):
|
||||
r, err = b.downloadFile(link)
|
||||
r, err = b.downloader.DownloadFile(link)
|
||||
// probably path to a local file
|
||||
default:
|
||||
r, err = readFile(link)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package lists
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/0xERR0R/blocky/evt"
|
||||
|
@ -28,7 +29,6 @@ var _ = Describe("ListCache", func() {
|
|||
file1 = TempFile("blocked1.com\nblocked1a.com")
|
||||
file2 = TempFile("blocked2.com")
|
||||
file3 = TempFile("blocked3.com\nblocked1a.com")
|
||||
|
||||
})
|
||||
AfterEach(func() {
|
||||
_ = os.Remove(emptyFile.Name())
|
||||
|
@ -46,7 +46,8 @@ var _ = Describe("ListCache", func() {
|
|||
lists := map[string][]string{
|
||||
"gr0": {emptyFile.Name()},
|
||||
}
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Second, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
found, group := sut.Match("", []string{"gr0"})
|
||||
Expect(found).Should(BeFalse())
|
||||
|
@ -59,75 +60,40 @@ var _ = Describe("ListCache", func() {
|
|||
lists := map[string][]string{
|
||||
"gr1": {emptyFile.Name()},
|
||||
}
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Second, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
found, group := sut.Match("google.com", []string{"gr1"})
|
||||
Expect(found).Should(BeFalse())
|
||||
Expect(group).Should(BeEmpty())
|
||||
|
||||
})
|
||||
})
|
||||
When("If timeout occurs", func() {
|
||||
var attempt uint64 = 1
|
||||
It("Should perform a retry", func() {
|
||||
failedDownloadCount := 0
|
||||
_ = Bus().SubscribeOnce(CachingFailedDownloadChanged, func(_ string) {
|
||||
failedDownloadCount++
|
||||
})
|
||||
|
||||
// should produce a timeout on first attempt
|
||||
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
a := atomic.LoadUint64(&attempt)
|
||||
if a == 1 {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
} else {
|
||||
_, err := rw.Write([]byte("blocked1.com"))
|
||||
Expect(err).Should(Succeed())
|
||||
}
|
||||
atomic.AddUint64(&attempt, 1)
|
||||
}))
|
||||
defer s.Close()
|
||||
lists := map[string][]string{
|
||||
"gr1": {s.URL},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(
|
||||
ListCacheTypeBlacklist, lists,
|
||||
0, 400*time.Millisecond, 3, time.Millisecond,
|
||||
&http.Transport{},
|
||||
)
|
||||
Eventually(func(g Gomega) {
|
||||
found, group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
g.Expect(found).Should(BeTrue())
|
||||
g.Expect(group).Should(Equal("gr1"))
|
||||
}, "1s").Should(Succeed())
|
||||
|
||||
Expect(failedDownloadCount).Should(Equal(1))
|
||||
})
|
||||
})
|
||||
When("a temporary err occurs on download", func() {
|
||||
var attempt uint64 = 1
|
||||
When("a temporary/transient err occurs on download", func() {
|
||||
It("should not delete existing elements from group cache", func() {
|
||||
// should produce a timeout on second attempt
|
||||
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
a := atomic.LoadUint64(&attempt)
|
||||
if a != 1 {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
} else {
|
||||
_, err := rw.Write([]byte("blocked1.com"))
|
||||
Expect(err).Should(Succeed())
|
||||
}
|
||||
atomic.AddUint64(&attempt, 1)
|
||||
}))
|
||||
defer s.Close()
|
||||
// should produce a transient error on second and third attempt
|
||||
data := make(chan (func() (io.ReadCloser, error)), 3)
|
||||
mockDownloader := &MockDownloader{data: data}
|
||||
data <- func() (io.ReadCloser, error) { //nolint:unparam
|
||||
return io.NopCloser(strings.NewReader("blocked1.com")), nil
|
||||
}
|
||||
data <- func() (io.ReadCloser, error) { //nolint:unparam
|
||||
return nil, &TransientError{inner: errors.New("boom")}
|
||||
}
|
||||
data <- func() (io.ReadCloser, error) { //nolint:unparam
|
||||
return nil, &TransientError{inner: errors.New("boom")}
|
||||
}
|
||||
lists := map[string][]string{
|
||||
"gr1": {s.URL, emptyFile.Name()},
|
||||
"gr1": {"http://dummy"},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(
|
||||
sut, err := NewListCache(
|
||||
ListCacheTypeBlacklist, lists,
|
||||
4*time.Hour, 100*time.Millisecond, 3, time.Millisecond,
|
||||
&http.Transport{},
|
||||
4*time.Hour,
|
||||
mockDownloader,
|
||||
)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
By("Lists loaded without timeout", func() {
|
||||
Eventually(func(g Gomega) {
|
||||
found, group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
|
@ -137,6 +103,14 @@ var _ = Describe("ListCache", func() {
|
|||
|
||||
})
|
||||
|
||||
Expect(sut.refresh(true)).Should(HaveOccurred())
|
||||
|
||||
By("List couldn't be loaded due to timeout", func() {
|
||||
found, group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
Expect(found).Should(BeTrue())
|
||||
Expect(group).Should(Equal("gr1"))
|
||||
})
|
||||
|
||||
sut.Refresh()
|
||||
|
||||
By("List couldn't be loaded due to timeout", func() {
|
||||
|
@ -146,26 +120,24 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
When("err occurs on download", func() {
|
||||
var attempt uint64 = 1
|
||||
When("non transient err occurs on download", func() {
|
||||
It("should delete existing elements from group cache", func() {
|
||||
// should produce a 404 err on second attempt
|
||||
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
a := atomic.LoadUint64(&attempt)
|
||||
if a != 1 {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
} else {
|
||||
_, err := rw.Write([]byte("blocked1.com"))
|
||||
Expect(err).Should(Succeed())
|
||||
}
|
||||
atomic.AddUint64(&attempt, 1)
|
||||
}))
|
||||
defer s.Close()
|
||||
data := make(chan (func() (io.ReadCloser, error)), 2)
|
||||
mockDownloader := &MockDownloader{data: data}
|
||||
data <- func() (io.ReadCloser, error) { //nolint:unparam
|
||||
return io.NopCloser(strings.NewReader("blocked1.com")), nil
|
||||
}
|
||||
data <- func() (io.ReadCloser, error) {
|
||||
return nil, errors.New("boom")
|
||||
}
|
||||
lists := map[string][]string{
|
||||
"gr1": {s.URL},
|
||||
"gr1": {"http://dummy"},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, mockDownloader)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
By("Lists loaded without err", func() {
|
||||
Eventually(func(g Gomega) {
|
||||
found, group := sut.Match("blocked1.com", []string{"gr1"})
|
||||
|
@ -175,7 +147,7 @@ var _ = Describe("ListCache", func() {
|
|||
|
||||
})
|
||||
|
||||
sut.Refresh()
|
||||
Expect(sut.refresh(false)).Should(HaveOccurred())
|
||||
|
||||
By("List couldn't be loaded due to 404 err", func() {
|
||||
Eventually(func() bool {
|
||||
|
@ -192,7 +164,7 @@ var _ = Describe("ListCache", func() {
|
|||
"gr2": {server3.URL},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
|
||||
found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
|
||||
Expect(found).Should(BeTrue())
|
||||
|
@ -206,19 +178,6 @@ var _ = Describe("ListCache", func() {
|
|||
Expect(found).Should(BeTrue())
|
||||
Expect(group).Should(Equal("gr2"))
|
||||
})
|
||||
It("should not match if no groups are passed", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {server1.URL, server2.URL},
|
||||
"gr2": {server3.URL},
|
||||
"withDeadLink": {"http://wrong.host.name"},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
|
||||
|
||||
found, group := sut.Match("blocked1.com", []string{})
|
||||
Expect(found).Should(BeFalse())
|
||||
Expect(group).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
When("List will be updated", func() {
|
||||
It("event should be fired and contain count of elements in downloaded lists", func() {
|
||||
|
@ -232,7 +191,8 @@ var _ = Describe("ListCache", func() {
|
|||
resultCnt = cnt
|
||||
})
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
found, group := sut.Match("blocked1.com", []string{})
|
||||
Expect(found).Should(BeFalse())
|
||||
|
@ -247,7 +207,8 @@ var _ = Describe("ListCache", func() {
|
|||
"gr2": {"file://" + file3.Name()},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
|
||||
Expect(found).Should(BeTrue())
|
||||
|
@ -265,16 +226,32 @@ var _ = Describe("ListCache", func() {
|
|||
When("inline list content is defined", func() {
|
||||
It("should match", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {"inlinedomain1.com\n#some comment\n#inlinedomain2.com"},
|
||||
"gr1": {"inlinedomain1.com\n#some comment\ninlinedomain2.com"},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
|
||||
Expect(found).Should(BeTrue())
|
||||
Expect(group).Should(Equal("gr1"))
|
||||
|
||||
found, group = sut.Match("inlinedomain1.com", []string{"gr1"})
|
||||
found, group = sut.Match("inlinedomain2.com", []string{"gr1"})
|
||||
Expect(found).Should(BeTrue())
|
||||
Expect(group).Should(Equal("gr1"))
|
||||
})
|
||||
})
|
||||
When("Text file can't be parsed", func() {
|
||||
It("should still match already imported strings", func() {
|
||||
// 2nd line is too long and will cause an error
|
||||
lists := map[string][]string{
|
||||
"gr1": {"inlinedomain1.com\n" + strings.Repeat("longString", 100000)},
|
||||
}
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
|
||||
Expect(found).Should(BeTrue())
|
||||
Expect(group).Should(Equal("gr1"))
|
||||
})
|
||||
|
@ -285,7 +262,8 @@ var _ = Describe("ListCache", func() {
|
|||
"gr1": {"/^apple\\.(de|com)$/\n"},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
found, group := sut.Match("apple.com", []string{"gr1"})
|
||||
Expect(found).Should(BeTrue())
|
||||
|
@ -305,19 +283,22 @@ var _ = Describe("ListCache", func() {
|
|||
"gr2": {"inline\ndefinition\n"},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement("refresh period: 1 hour"))
|
||||
Expect(c).Should(HaveLen(11))
|
||||
})
|
||||
})
|
||||
When("refresh is disabled", func() {
|
||||
It("should print 'refresh disabled'", func() {
|
||||
lists := map[string][]string{
|
||||
"gr1": {"file1", "file2"},
|
||||
"gr1": {emptyFile.Name()},
|
||||
}
|
||||
|
||||
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, -1, 0, 3, time.Millisecond, &http.Transport{})
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader())
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
c := sut.Configuration()
|
||||
Expect(c).Should(ContainElement("refresh: disabled"))
|
||||
|
@ -325,3 +306,12 @@ var _ = Describe("ListCache", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
// MockDownloader is a FileDownloader test double: each DownloadFile call
// consumes and executes the next prepared result function from the channel.
type MockDownloader struct {
	data chan (func() (io.ReadCloser, error))
}

// DownloadFile returns whatever the next queued function produces; it blocks
// if no result has been queued.
func (m *MockDownloader) DownloadFile(link string) (io.ReadCloser, error) {
	fn := <-m.data

	return fn()
}
|
||||
|
|
|
@ -93,13 +93,14 @@ func NewBlockingResolver(cfg config.BlockingConfig,
|
|||
redis *redis.Client, bootstrap *Bootstrap) (r ChainedResolver, err error) {
|
||||
blockHandler := createBlockHandler(cfg)
|
||||
refreshPeriod := time.Duration(cfg.RefreshPeriod)
|
||||
timeout := time.Duration(cfg.DownloadTimeout)
|
||||
cooldown := time.Duration(cfg.DownloadCooldown)
|
||||
transport := bootstrap.NewHTTPTransport()
|
||||
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists, refreshPeriod,
|
||||
timeout, cfg.DownloadAttempts, cooldown, transport)
|
||||
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.WhiteLists, refreshPeriod,
|
||||
timeout, cfg.DownloadAttempts, cooldown, transport)
|
||||
downloader := lists.NewDownloader(
|
||||
lists.WithTimeout(time.Duration(cfg.DownloadTimeout)),
|
||||
lists.WithAttempts(cfg.DownloadAttempts),
|
||||
lists.WithCooldown(time.Duration(cfg.DownloadCooldown)),
|
||||
lists.WithTransport(bootstrap.NewHTTPTransport()),
|
||||
)
|
||||
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists, refreshPeriod, downloader)
|
||||
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.WhiteLists, refreshPeriod, downloader)
|
||||
whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg)
|
||||
|
||||
err = multierror.Append(err, blErr, wlErr).ErrorOrNil()
|
||||
|
|
Loading…
Reference in New Issue