chore: refactoring list cache, extracting download functionality (#508)

Dimitri Herzog 2022-05-06 17:57:33 +02:00 committed by GitHub
parent 8e472aa271
commit 53814a2208
6 changed files with 489 additions and 191 deletions

@ -473,7 +473,7 @@ type BlockingConfig struct {
BlockType string `yaml:"blockType" default:"ZEROIP"`
BlockTTL Duration `yaml:"blockTTL" default:"6h"`
DownloadTimeout Duration `yaml:"downloadTimeout" default:"60s"`
DownloadAttempts int `yaml:"downloadAttempts" default:"3"`
DownloadAttempts uint `yaml:"downloadAttempts" default:"3"`
DownloadCooldown Duration `yaml:"downloadCooldown" default:"1s"`
RefreshPeriod Duration `yaml:"refreshPeriod" default:"4h"`
FailStartOnListError bool `yaml:"failStartOnListError" default:"false"`
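
DownloadAttempts changes from int to uint because retry-go's Attempts option takes a uint, so the type change removes a conversion at the retry call site. Compare the old call in ListCache.downloadFile with the new one in HTTPDownloader.DownloadFile (both appear later in this commit):

	// before: the config value was an int and needed a cast
	retry.Attempts(uint(b.downloadAttempts))
	// after: uint end to end, no conversion needed
	retry.Attempts(d.downloadAttempts)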

lists/downloader.go (new file, 152 additions)
@ -0,0 +1,152 @@
package lists

import (
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"time"

	"github.com/0xERR0R/blocky/evt"
	"github.com/avast/retry-go/v4"
)

const (
	defaultDownloadTimeout  = time.Second
	defaultDownloadAttempts = uint(1)
	defaultDownloadCooldown = 500 * time.Millisecond
)

// TransientError represents a temporary error, such as a timeout or a network error
type TransientError struct {
	inner error
}

func (e *TransientError) Error() string {
	return fmt.Sprintf("temporary error occurred: %v", e.inner)
}

func (e *TransientError) Unwrap() error {
	return e.inner
}

// FileDownloader downloads a text file from the given link
type FileDownloader interface {
	DownloadFile(link string) (io.ReadCloser, error)
}

// HTTPDownloader downloads files via the HTTP protocol
type HTTPDownloader struct {
	downloadTimeout  time.Duration
	downloadAttempts uint
	downloadCooldown time.Duration
	httpTransport    *http.Transport
}

type DownloaderOption func(c *HTTPDownloader)

func NewDownloader(options ...DownloaderOption) *HTTPDownloader {
	d := &HTTPDownloader{
		downloadTimeout:  defaultDownloadTimeout,
		downloadAttempts: defaultDownloadAttempts,
		downloadCooldown: defaultDownloadCooldown,
		httpTransport:    &http.Transport{},
	}

	for _, opt := range options {
		opt(d)
	}

	return d
}

// WithTimeout sets the download timeout
func WithTimeout(timeout time.Duration) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.downloadTimeout = timeout
	}
}

// WithCooldown sets the pause between two download attempts
func WithCooldown(cooldown time.Duration) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.downloadCooldown = cooldown
	}
}

// WithAttempts sets the maximum number of download attempts
func WithAttempts(downloadAttempts uint) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.downloadAttempts = downloadAttempts
	}
}

// WithTransport sets the HTTP transport
func WithTransport(httpTransport *http.Transport) DownloaderOption {
	return func(d *HTTPDownloader) {
		d.httpTransport = httpTransport
	}
}

func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
	client := http.Client{
		Timeout:   d.downloadTimeout,
		Transport: d.httpTransport,
	}

	logger().WithField("link", link).Info("starting download")

	var body io.ReadCloser

	err := retry.Do(
		func() error {
			var resp *http.Response
			var httpErr error
			//nolint:bodyclose
			if resp, httpErr = client.Get(link); httpErr == nil {
				if resp.StatusCode == http.StatusOK {
					body = resp.Body

					return nil
				}

				_ = resp.Body.Close()

				return fmt.Errorf("got status code %d", resp.StatusCode)
			}

			var netErr net.Error
			if errors.As(httpErr, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
				return &TransientError{inner: netErr}
			}

			return httpErr
		},
		retry.Attempts(d.downloadAttempts),
		retry.DelayType(retry.FixedDelay),
		retry.Delay(d.downloadCooldown),
		retry.LastErrorOnly(true),
		retry.OnRetry(func(n uint, err error) {
			var transientErr *TransientError

			var dnsErr *net.DNSError

			logger := logger().WithField("link", link).WithField("attempt",
				fmt.Sprintf("%d/%d", n+1, d.downloadAttempts))

			switch {
			case errors.As(err, &transientErr):
				logger.Warnf("Temporary network err / Timeout occurred: %s", transientErr)
			case errors.As(err, &dnsErr):
				logger.Warnf("Name resolution err: %s", dnsErr.Err)
			default:
				logger.Warnf("Can't download file: %s", err)
			}

			onDownloadError(link)
		}))

	return body, err
}

func onDownloadError(link string) {
	evt.Bus().Publish(evt.CachingFailedDownloadChanged, link)
}
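
For orientation, a minimal usage sketch of the new component, assuming the surrounding lists package context: fetchList is a hypothetical helper (not part of this commit), and the timeout, attempt, and cooldown values are illustrative. The TransientError check mirrors what the list cache does further below.

	// fetchList is a hypothetical helper showing the intended call pattern
	func fetchList(link string) (io.ReadCloser, error) {
		downloader := NewDownloader(
			WithTimeout(5*time.Second),
			WithAttempts(3),
			WithCooldown(500*time.Millisecond),
		)

		reader, err := downloader.DownloadFile(link)
		if err != nil {
			var transientErr *TransientError
			if errors.As(err, &transientErr) {
				// temporary failure (timeout, temporary network error):
				// the list cache keeps previously imported entries in this case
			}

			return nil, err
		}

		// the caller must close the reader after consuming the list
		return reader, nil
	}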

lists/downloader_test.go (new file, 217 additions)
@ -0,0 +1,217 @@
package lists

import (
	"errors"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"time"

	. "github.com/0xERR0R/blocky/evt"
	. "github.com/0xERR0R/blocky/helpertest"
	"github.com/0xERR0R/blocky/log"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
	"github.com/sirupsen/logrus/hooks/test"
)

var _ = Describe("Downloader", func() {
	var (
		sut                           *HTTPDownloader
		failedDownloadCountEvtChannel chan string
		loggerHook                    *test.Hook
	)

	BeforeEach(func() {
		failedDownloadCountEvtChannel = make(chan string, 5)
		// collect received events in the channel
		fn := func(url string) {
			failedDownloadCountEvtChannel <- url
		}
		Expect(Bus().Subscribe(CachingFailedDownloadChanged, fn)).Should(Succeed())
		DeferCleanup(func() {
			Expect(Bus().Unsubscribe(CachingFailedDownloadChanged, fn))
		})

		loggerHook = test.NewGlobal()
		log.Log().AddHook(loggerHook)
		DeferCleanup(loggerHook.Reset)
	})

	Describe("Construct downloader", func() {
		When("No options are provided", func() {
			BeforeEach(func() {
				sut = NewDownloader()
			})
			It("Should provide default values", func() {
				Expect(sut.downloadAttempts).Should(BeNumerically("==", defaultDownloadAttempts))
				Expect(sut.downloadTimeout).Should(BeNumerically("==", defaultDownloadTimeout))
				Expect(sut.downloadCooldown).Should(BeNumerically("==", defaultDownloadCooldown))
			})
		})

		When("Options are provided", func() {
			transport := &http.Transport{}
			BeforeEach(func() {
				sut = NewDownloader(
					WithAttempts(5),
					WithCooldown(2*time.Second),
					WithTimeout(5*time.Second),
					WithTransport(transport),
				)
			})
			It("Should use provided parameters", func() {
				Expect(sut.downloadAttempts).Should(BeNumerically("==", 5))
				Expect(sut.downloadTimeout).Should(BeNumerically("==", 5*time.Second))
				Expect(sut.downloadCooldown).Should(BeNumerically("==", 2*time.Second))
				Expect(sut.httpTransport).Should(BeIdenticalTo(transport))
			})
		})
	})

	Describe("Download of a file", func() {
		var server *httptest.Server

		When("Download was successful", func() {
			BeforeEach(func() {
				server = TestServer("line.one\nline.two")
				DeferCleanup(server.Close)

				sut = NewDownloader()
			})
			It("Should return all lines from the file", func() {
				reader, err := sut.DownloadFile(server.URL)
				Expect(err).Should(Succeed())
				Expect(reader).Should(Not(BeNil()))
				DeferCleanup(reader.Close)

				buf := new(strings.Builder)
				_, err = io.Copy(buf, reader)
				Expect(err).Should(Succeed())
				Expect(buf.String()).Should(Equal("line.one\nline.two"))
			})
		})

		When("Server returns NOT_FOUND (404)", func() {
			BeforeEach(func() {
				server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
					rw.WriteHeader(http.StatusNotFound)
				}))
				DeferCleanup(server.Close)

				sut = NewDownloader(WithAttempts(3))
			})
			It("Should return error", func() {
				reader, err := sut.DownloadFile(server.URL)
				Expect(err).Should(HaveOccurred())
				Expect(reader).Should(BeNil())
				Expect(err.Error()).Should(Equal("got status code 404"))
				Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
				Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
			})
		})

		When("Wrong URL is defined", func() {
			BeforeEach(func() {
				sut = NewDownloader()
			})
			It("Should return error", func() {
				_, err := sut.DownloadFile("somewrongurl")
				Expect(err).Should(HaveOccurred())

				Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Can't download file: "))
				// failed download event was emitted only once
				Expect(failedDownloadCountEvtChannel).Should(HaveLen(1))
				Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("somewrongurl")))
			})
		})

		When("A timeout occurs on the first request", func() {
			var attempt uint64 = 1
			BeforeEach(func() {
				sut = NewDownloader(
					WithTimeout(20*time.Millisecond),
					WithAttempts(3),
					WithCooldown(time.Millisecond))

				// should produce a timeout on the first attempt
				server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
					a := atomic.LoadUint64(&attempt)
					atomic.AddUint64(&attempt, 1)
					if a == 1 {
						time.Sleep(500 * time.Millisecond)
					} else {
						_, err := rw.Write([]byte("blocked1.com"))
						Expect(err).Should(Succeed())
					}
				}))
				DeferCleanup(server.Close)
			})
			It("Should perform a retry and return file content", func() {
				reader, err := sut.DownloadFile(server.URL)
				Expect(err).Should(Succeed())
				Expect(reader).Should(Not(BeNil()))
				DeferCleanup(reader.Close)

				buf := new(strings.Builder)
				_, err = io.Copy(buf, reader)
				Expect(err).Should(Succeed())
				Expect(buf.String()).Should(Equal("blocked1.com"))

				// failed download event was emitted only once
				Expect(failedDownloadCountEvtChannel).Should(HaveLen(1))
				Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
				Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Temporary network err / Timeout occurred: "))
			})
		})

		When("A timeout occurs on all requests", func() {
			BeforeEach(func() {
				sut = NewDownloader(
					WithTimeout(100*time.Millisecond),
					WithAttempts(3),
					WithCooldown(time.Millisecond))

				// should always produce a timeout
				server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
					time.Sleep(200 * time.Millisecond)
				}))
				DeferCleanup(server.Close)
			})
			It("Should retry until the max attempt count is reached and return a TransientError", func() {
				reader, err := sut.DownloadFile(server.URL)
				Expect(err).Should(HaveOccurred())

				var transientErr *TransientError
				Expect(errors.As(err, &transientErr)).To(BeTrue())
				Expect(transientErr.Unwrap().Error()).Should(ContainSubstring("Timeout"))
				Expect(reader).Should(BeNil())

				// failed download event was emitted 3 times
				Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
				Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
			})
		})

		When("DNS resolution of the passed URL fails", func() {
			BeforeEach(func() {
				sut = NewDownloader(
					WithTimeout(100*time.Millisecond),
					WithAttempts(3),
					WithCooldown(time.Millisecond))
			})
			It("Should retry until the max attempt count is reached and return a DNSError", func() {
				reader, err := sut.DownloadFile("http://some.domain.which.does.not.exist")
				Expect(err).Should(HaveOccurred())

				var dnsError *net.DNSError
				Expect(errors.As(err, &dnsError)).To(BeTrue())
				Expect(reader).Should(BeNil())

				// failed download event was emitted 3 times
				Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
				Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("http://some.domain.which.does.not.exist")))
				Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: "))
			})
		})
	})
})
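
The event-count assertions above (HaveLen(3) for three attempts) rely on retry-go invoking the OnRetry hook after every failed attempt, including the last one, which is where onDownloadError publishes the event. A minimal standalone sketch of that accounting, assuming retry-go v4 behaves as these tests exercise it:

	package main

	import (
		"errors"
		"fmt"
		"time"

		"github.com/avast/retry-go/v4"
	)

	func main() {
		var calls, retries uint

		err := retry.Do(
			func() error {
				calls++
				return errors.New("boom") // always fails, like the 404 server in the tests
			},
			retry.Attempts(3),                 // three attempts in total
			retry.DelayType(retry.FixedDelay), // constant pause, like downloadCooldown
			retry.Delay(time.Millisecond),
			retry.LastErrorOnly(true), // return only the final error
			retry.OnRetry(func(n uint, err error) {
				retries++ // fires once per failed attempt; n is zero-based
			}),
		)

		fmt.Println(calls, retries, err) // expected: 3 3 boom
	}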

@ -7,15 +7,13 @@ import (
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/0xERR0R/blocky/cache/stringcache"
"github.com/avast/retry-go/v4"
"github.com/sirupsen/logrus"
"github.com/hako/durafmt"
@ -23,7 +21,6 @@ import (
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/log"
"github.com/sirupsen/logrus"
)
// ListCacheType represents the type of cached list ENUM(
@ -46,13 +43,10 @@ type ListCache struct {
groupCaches map[string]stringcache.StringCache
lock sync.RWMutex
groupToLinks map[string][]string
refreshPeriod time.Duration
downloadTimeout time.Duration
downloadAttempts int
downloadCooldown time.Duration
httpTransport *http.Transport
listType ListCacheType
groupToLinks map[string][]string
refreshPeriod time.Duration
downloader FileDownloader
listType ListCacheType
}
// Configuration returns current configuration and stats
@ -80,6 +74,9 @@ func (b *ListCache) Configuration() (result []string) {
var total int
b.lock.RLock()
defer b.lock.RUnlock()
for group, cache := range b.groupCaches {
result = append(result, fmt.Sprintf(" %s: %d entries", group, cache.ElementCount()))
total += cache.ElementCount()
@ -92,19 +89,15 @@ func (b *ListCache) Configuration() (result []string) {
// NewListCache creates new list instance
func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeriod time.Duration,
downloadTimeout time.Duration, downloadAttempts int,
downloadCooldown time.Duration, httpTransport *http.Transport) (*ListCache, error) {
downloader FileDownloader) (*ListCache, error) {
groupCaches := make(map[string]stringcache.StringCache)
b := &ListCache{
groupToLinks: groupToLinks,
groupCaches: groupCaches,
refreshPeriod: refreshPeriod,
downloadTimeout: downloadTimeout,
downloadAttempts: downloadAttempts,
downloadCooldown: downloadCooldown,
httpTransport: httpTransport,
listType: t,
groupToLinks: groupToLinks,
groupCaches: groupCaches,
refreshPeriod: refreshPeriod,
downloader: downloader,
listType: t,
}
initError := b.refresh(true)
@ -217,12 +210,12 @@ func (b *ListCache) refresh(init bool) error {
}
}
if b.groupCaches[group] != nil {
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, b.groupCaches[group].ElementCount())
if cacheForGroup != nil {
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, cacheForGroup.ElementCount())
logger().WithFields(logrus.Fields{
"group": group,
"total_count": b.groupCaches[group].ElementCount(),
"total_count": cacheForGroup.ElementCount(),
}).Info("group import finished")
}
}
@ -230,61 +223,6 @@ func (b *ListCache) refresh(init bool) error {
return err
}
func (b *ListCache) downloadFile(link string) (io.ReadCloser, error) {
client := http.Client{
Timeout: b.downloadTimeout,
Transport: b.httpTransport,
}
var resp *http.Response
logger().WithField("link", link).Info("starting download")
var body io.ReadCloser
err := retry.Do(
func() error {
var err error
//nolint:bodyclose
if resp, err = client.Get(link); err == nil {
if resp.StatusCode == http.StatusOK {
body = resp.Body
return nil
}
_ = resp.Body.Close()
return fmt.Errorf("got status code %d", resp.StatusCode)
}
return err
},
retry.Attempts(uint(b.downloadAttempts)),
retry.DelayType(retry.FixedDelay),
retry.Delay(b.downloadCooldown),
retry.LastErrorOnly(true),
retry.OnRetry(func(n uint, err error) {
var netErr net.Error
var dnsErr *net.DNSError
logger := logger().WithField("link", link).WithField("attempt",
fmt.Sprintf("%d/%d", n+1, b.downloadAttempts))
switch {
case errors.As(err, &netErr) && (netErr.Timeout() || netErr.Temporary()):
logger.Warnf("Temporary network err / Timeout occurred: %s", netErr)
case errors.As(err, &dnsErr):
logger.Warnf("Name resolution err: %s", dnsErr.Err)
default:
logger.Warnf("Can't download file: %s", err)
}
evt.Bus().Publish(evt.CachingFailedDownloadChanged, link)
}))
return body, err
}
func readFile(file string) (io.ReadCloser, error) {
logger().WithField("file", file).Info("starting processing of file")
file = strings.TrimPrefix(file, "file://")
@ -307,12 +245,12 @@ func (b *ListCache) processFile(link string, ch chan<- groupCache, wg *sync.Wait
r, err = b.getLinkReader(link)
if err != nil {
logger().Warn("err during file processing: ", err)
logger().Warn("error during file processing: ", err)
result.err = multierror.Append(result.err, err)
var netErr net.Error
var transientErr *TransientError
if errors.As(err, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
if errors.As(err, &transientErr) {
// put nil to indicate the temporary err
result.cache = nil
}
@ -354,7 +292,7 @@ func (b *ListCache) getLinkReader(link string) (r io.ReadCloser, err error) {
r = io.NopCloser(strings.NewReader(link))
// link is http(s) -> download it
case strings.HasPrefix(link, "http"):
r, err = b.downloadFile(link)
r, err = b.downloader.DownloadFile(link)
// probably path to a local file
default:
r, err = readFile(link)
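
With the extraction in place, NewListCache takes a FileDownloader instead of four transport-related parameters. A minimal sketch of the new call shape, with a placeholder group name and URL; any FileDownloader works here (production code passes an HTTPDownloader, the tests below pass a mock):

	groupToLinks := map[string][]string{
		"ads": {"https://example.com/blocklist.txt"}, // placeholder group and link
	}

	cache, err := NewListCache(ListCacheTypeBlacklist, groupToLinks, 4*time.Hour, NewDownloader())
	if err != nil {
		// the constructor performs an initial refresh; list errors can surface here
	}

	found, group := cache.Match("blocked.example.com", []string{"ads"})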

@ -1,10 +1,11 @@
package lists
import (
"net/http"
"errors"
"io"
"net/http/httptest"
"os"
"sync/atomic"
"strings"
"time"
. "github.com/0xERR0R/blocky/evt"
@ -28,7 +29,6 @@ var _ = Describe("ListCache", func() {
file1 = TempFile("blocked1.com\nblocked1a.com")
file2 = TempFile("blocked2.com")
file3 = TempFile("blocked3.com\nblocked1a.com")
})
AfterEach(func() {
_ = os.Remove(emptyFile.Name())
@ -46,7 +46,8 @@ var _ = Describe("ListCache", func() {
lists := map[string][]string{
"gr0": {emptyFile.Name()},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Second, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
Expect(err).Should(Succeed())
found, group := sut.Match("", []string{"gr0"})
Expect(found).Should(BeFalse())
@ -59,75 +60,40 @@ var _ = Describe("ListCache", func() {
lists := map[string][]string{
"gr1": {emptyFile.Name()},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Second, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
Expect(err).Should(Succeed())
found, group := sut.Match("google.com", []string{"gr1"})
Expect(found).Should(BeFalse())
Expect(group).Should(BeEmpty())
})
})
When("If timeout occurs", func() {
var attempt uint64 = 1
It("Should perform a retry", func() {
failedDownloadCount := 0
_ = Bus().SubscribeOnce(CachingFailedDownloadChanged, func(_ string) {
failedDownloadCount++
})
// should produce a timeout on first attempt
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
a := atomic.LoadUint64(&attempt)
if a == 1 {
time.Sleep(500 * time.Millisecond)
} else {
_, err := rw.Write([]byte("blocked1.com"))
Expect(err).Should(Succeed())
}
atomic.AddUint64(&attempt, 1)
}))
defer s.Close()
lists := map[string][]string{
"gr1": {s.URL},
}
sut, _ := NewListCache(
ListCacheTypeBlacklist, lists,
0, 400*time.Millisecond, 3, time.Millisecond,
&http.Transport{},
)
Eventually(func(g Gomega) {
found, group := sut.Match("blocked1.com", []string{"gr1"})
g.Expect(found).Should(BeTrue())
g.Expect(group).Should(Equal("gr1"))
}, "1s").Should(Succeed())
Expect(failedDownloadCount).Should(Equal(1))
})
})
When("a temporary err occurs on download", func() {
var attempt uint64 = 1
When("a temporary/transient err occurs on download", func() {
It("should not delete existing elements from group cache", func() {
// should produce a timeout on second attempt
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
a := atomic.LoadUint64(&attempt)
if a != 1 {
time.Sleep(200 * time.Millisecond)
} else {
_, err := rw.Write([]byte("blocked1.com"))
Expect(err).Should(Succeed())
}
atomic.AddUint64(&attempt, 1)
}))
defer s.Close()
// should produce a transient error on second and third attempt
data := make(chan (func() (io.ReadCloser, error)), 3)
mockDownloader := &MockDownloader{data: data}
data <- func() (io.ReadCloser, error) { //nolint:unparam
return io.NopCloser(strings.NewReader("blocked1.com")), nil
}
data <- func() (io.ReadCloser, error) { //nolint:unparam
return nil, &TransientError{inner: errors.New("boom")}
}
data <- func() (io.ReadCloser, error) { //nolint:unparam
return nil, &TransientError{inner: errors.New("boom")}
}
lists := map[string][]string{
"gr1": {s.URL, emptyFile.Name()},
"gr1": {"http://dummy"},
}
sut, _ := NewListCache(
sut, err := NewListCache(
ListCacheTypeBlacklist, lists,
4*time.Hour, 100*time.Millisecond, 3, time.Millisecond,
&http.Transport{},
4*time.Hour,
mockDownloader,
)
Expect(err).Should(Succeed())
By("Lists loaded without timeout", func() {
Eventually(func(g Gomega) {
found, group := sut.Match("blocked1.com", []string{"gr1"})
@ -137,6 +103,14 @@ var _ = Describe("ListCache", func() {
})
Expect(sut.refresh(true)).Should(HaveOccurred())
By("List couldn't be loaded due to timeout", func() {
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
})
sut.Refresh()
By("List couldn't be loaded due to timeout", func() {
@ -146,26 +120,24 @@ var _ = Describe("ListCache", func() {
})
})
})
When("err occurs on download", func() {
var attempt uint64 = 1
When("non transient err occurs on download", func() {
It("should delete existing elements from group cache", func() {
// should produce a 404 err on second attempt
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
a := atomic.LoadUint64(&attempt)
if a != 1 {
rw.WriteHeader(http.StatusNotFound)
} else {
_, err := rw.Write([]byte("blocked1.com"))
Expect(err).Should(Succeed())
}
atomic.AddUint64(&attempt, 1)
}))
defer s.Close()
data := make(chan (func() (io.ReadCloser, error)), 2)
mockDownloader := &MockDownloader{data: data}
data <- func() (io.ReadCloser, error) { //nolint:unparam
return io.NopCloser(strings.NewReader("blocked1.com")), nil
}
data <- func() (io.ReadCloser, error) {
return nil, errors.New("boom")
}
lists := map[string][]string{
"gr1": {s.URL},
"gr1": {"http://dummy"},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, mockDownloader)
Expect(err).Should(Succeed())
By("Lists loaded without err", func() {
Eventually(func(g Gomega) {
found, group := sut.Match("blocked1.com", []string{"gr1"})
@ -175,7 +147,7 @@ var _ = Describe("ListCache", func() {
})
sut.Refresh()
Expect(sut.refresh(false)).Should(HaveOccurred())
By("List couldn't be loaded due to 404 err", func() {
Eventually(func() bool {
@ -192,7 +164,7 @@ var _ = Describe("ListCache", func() {
"gr2": {server3.URL},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
@ -206,19 +178,6 @@ var _ = Describe("ListCache", func() {
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr2"))
})
It("should not match if no groups are passed", func() {
lists := map[string][]string{
"gr1": {server1.URL, server2.URL},
"gr2": {server3.URL},
"withDeadLink": {"http://wrong.host.name"},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
found, group := sut.Match("blocked1.com", []string{})
Expect(found).Should(BeFalse())
Expect(group).Should(BeEmpty())
})
})
When("List will be updated", func() {
It("event should be fired and contain count of elements in downloaded lists", func() {
@ -232,7 +191,8 @@ var _ = Describe("ListCache", func() {
resultCnt = cnt
})
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 30*time.Second, 3, time.Millisecond, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
Expect(err).Should(Succeed())
found, group := sut.Match("blocked1.com", []string{})
Expect(found).Should(BeFalse())
@ -247,7 +207,8 @@ var _ = Describe("ListCache", func() {
"gr2": {"file://" + file3.Name()},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
Expect(err).Should(Succeed())
found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
@ -265,16 +226,32 @@ var _ = Describe("ListCache", func() {
When("inline list content is defined", func() {
It("should match", func() {
lists := map[string][]string{
"gr1": {"inlinedomain1.com\n#some comment\n#inlinedomain2.com"},
"gr1": {"inlinedomain1.com\n#some comment\ninlinedomain2.com"},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
Expect(err).Should(Succeed())
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
found, group = sut.Match("inlinedomain1.com", []string{"gr1"})
found, group = sut.Match("inlinedomain2.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
})
})
When("Text file can't be parsed", func() {
It("should still match already imported strings", func() {
// 2nd line is too long and will cause an error
lists := map[string][]string{
"gr1": {"inlinedomain1.com\n" + strings.Repeat("longString", 100000)},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
Expect(err).Should(Succeed())
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
})
@ -285,7 +262,8 @@ var _ = Describe("ListCache", func() {
"gr1": {"/^apple\\.(de|com)$/\n"},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader())
Expect(err).Should(Succeed())
found, group := sut.Match("apple.com", []string{"gr1"})
Expect(found).Should(BeTrue())
@ -305,19 +283,22 @@ var _ = Describe("ListCache", func() {
"gr2": {"inline\ndefinition\n"},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, 0, 3, time.Millisecond, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader())
Expect(err).Should(Succeed())
c := sut.Configuration()
Expect(c).Should(ContainElement("refresh period: 1 hour"))
Expect(c).Should(HaveLen(11))
})
})
When("refresh is disabled", func() {
It("should print 'refresh disabled'", func() {
lists := map[string][]string{
"gr1": {"file1", "file2"},
"gr1": {emptyFile.Name()},
}
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, -1, 0, 3, time.Millisecond, &http.Transport{})
sut, err := NewListCache(ListCacheTypeBlacklist, lists, -1, NewDownloader())
Expect(err).Should(Succeed())
c := sut.Configuration()
Expect(c).Should(ContainElement("refresh: disabled"))
@ -325,3 +306,12 @@ var _ = Describe("ListCache", func() {
})
})
})
type MockDownloader struct {
	data chan (func() (io.ReadCloser, error))
}

func (m *MockDownloader) DownloadFile(link string) (io.ReadCloser, error) {
	fn := <-m.data

	return fn()
}

@ -93,13 +93,14 @@ func NewBlockingResolver(cfg config.BlockingConfig,
redis *redis.Client, bootstrap *Bootstrap) (r ChainedResolver, err error) {
blockHandler := createBlockHandler(cfg)
refreshPeriod := time.Duration(cfg.RefreshPeriod)
timeout := time.Duration(cfg.DownloadTimeout)
cooldown := time.Duration(cfg.DownloadCooldown)
transport := bootstrap.NewHTTPTransport()
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists, refreshPeriod,
timeout, cfg.DownloadAttempts, cooldown, transport)
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.WhiteLists, refreshPeriod,
timeout, cfg.DownloadAttempts, cooldown, transport)
downloader := lists.NewDownloader(
lists.WithTimeout(time.Duration(cfg.DownloadTimeout)),
lists.WithAttempts(cfg.DownloadAttempts),
lists.WithCooldown(time.Duration(cfg.DownloadCooldown)),
lists.WithTransport(bootstrap.NewHTTPTransport()),
)
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.BlackLists, refreshPeriod, downloader)
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.WhiteLists, refreshPeriod, downloader)
whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg)
err = multierror.Append(err, blErr, wlErr).ErrorOrNil()