refactor: use groupedCache to optimize cache access (#944)

* refactor: use groupedCache to optimize cache access

* refactor: fix review findings
This commit is contained in:
Dimitri Herzog 2023-03-27 13:23:01 +02:00 committed by GitHub
parent 8757dea992
commit 3b9fd7bafe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 585 additions and 289 deletions

View File

@ -0,0 +1,78 @@
package stringcache
import (
"sort"
"golang.org/x/exp/maps"
)
// ChainedGroupedCache implements GroupedStringCache by delegating every
// operation to a list of child caches and aggregating their results.
type ChainedGroupedCache struct {
	caches []GroupedStringCache
}
// NewChainedGroupedCache creates a chained cache over the given delegate caches.
func NewChainedGroupedCache(caches ...GroupedStringCache) *ChainedGroupedCache {
	chained := ChainedGroupedCache{caches: caches}

	return &chained
}
// ElementCount returns the total number of elements stored for the group
// across all delegate caches.
func (c *ChainedGroupedCache) ElementCount(group string) int {
	total := 0
	for _, delegate := range c.caches {
		total += delegate.ElementCount(group)
	}

	return total
}
// Contains reports which of the given groups contain the search string in
// any delegate cache. The result is deduplicated and sorted for a
// deterministic order.
func (c *ChainedGroupedCache) Contains(searchString string, groups []string) []string {
	// set of matching group names, deduplicated across delegates
	seen := make(map[string]struct{}, len(groups))

	for _, delegate := range c.caches {
		for _, g := range delegate.Contains(searchString, groups) {
			seen[g] = struct{}{}
		}
	}

	matched := maps.Keys(seen)
	sort.Strings(matched)

	return matched
}
// Refresh creates a factory that fans every added entry out to a refresh
// factory of each delegate cache; Finish on the returned factory performs
// the refresh in all delegates.
func (c *ChainedGroupedCache) Refresh(group string) GroupFactory {
	factories := make([]GroupFactory, 0, len(c.caches))
	for _, delegate := range c.caches {
		factories = append(factories, delegate.Refresh(group))
	}

	return &chainedGroupFactory{cacheFactories: factories}
}
// chainedGroupFactory fans GroupFactory calls out to the refresh
// factories of all delegate caches.
type chainedGroupFactory struct {
	cacheFactories []GroupFactory
}
// AddEntry adds the entry to every delegate factory.
func (c *chainedGroupFactory) AddEntry(entry string) {
	for _, f := range c.cacheFactories {
		f.AddEntry(entry)
	}
}
// Count returns the sum of the entry counts of all delegate factories
// (an entry added once is therefore counted once per delegate).
func (c *chainedGroupFactory) Count() int {
	total := 0
	for _, f := range c.cacheFactories {
		total += f.Count()
	}

	return total
}
// Finish finalizes every delegate factory, swapping the refreshed group
// content into each delegate cache.
func (c *chainedGroupFactory) Finish() {
	for _, f := range c.cacheFactories {
		f.Finish()
	}
}

View File

@ -0,0 +1,94 @@
package stringcache_test
import (
"github.com/0xERR0R/blocky/cache/stringcache"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Tests for ChainedGroupedCache: an empty chain behaves like an empty
// cache, operations are delegated to all wrapped caches, and results are
// aggregated across them.
//
// NOTE(review): the cache/factory setup statements run at Ginkgo
// tree-construction time, so the It blocks below observe state in the
// order they appear — reordering them would change the outcome.
var _ = Describe("Chained grouped cache", func() {
	Describe("Empty cache", func() {
		When("empty cache was created", func() {
			// a chain with no delegates
			cache := stringcache.NewChainedGroupedCache()
			It("should have element count of 0", func() {
				Expect(cache.ElementCount("someGroup")).Should(BeNumerically("==", 0))
			})
			It("should not find any string", func() {
				Expect(cache.Contains("searchString", []string{"someGroup"})).Should(BeEmpty())
			})
		})
	})
	Describe("Delegation", func() {
		When("Chained cache contains delegates", func() {
			inMemoryCache1 := stringcache.NewInMemoryGroupedStringCache()
			inMemoryCache2 := stringcache.NewInMemoryGroupedStringCache()
			cache := stringcache.NewChainedGroupedCache(inMemoryCache1, inMemoryCache2)
			factory := cache.Refresh("group1")
			factory.AddEntry("string1")
			factory.AddEntry("string2")
			It("cache should still have 0 element, since finish was not executed", func() {
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 0))
			})
			// each entry goes to both delegate factories: 2 entries * 2 caches
			It("factory has 4 elements (both caches)", func() {
				Expect(factory.Count()).Should(BeNumerically("==", 4))
			})
			It("should have element count of 4", func() {
				factory.Finish()
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 4))
			})
			It("should find strings", func() {
				Expect(cache.Contains("string1", []string{"group1"})).Should(ConsistOf("group1"))
				Expect(cache.Contains("string2", []string{"group1", "someOtherGroup"})).Should(ConsistOf("group1"))
			})
		})
	})
	Describe("Cache refresh", func() {
		When("cache with 2 groups was created", func() {
			inMemoryCache1 := stringcache.NewInMemoryGroupedStringCache()
			inMemoryCache2 := stringcache.NewInMemoryGroupedStringCache()
			cache := stringcache.NewChainedGroupedCache(inMemoryCache1, inMemoryCache2)
			// "both" is present in group1 and group2; each group is stored
			// in both delegates, hence the doubled element counts below
			factory := cache.Refresh("group1")
			factory.AddEntry("g1")
			factory.AddEntry("both")
			factory.Finish()
			factory = cache.Refresh("group2")
			factory.AddEntry("g2")
			factory.AddEntry("both")
			factory.Finish()
			It("should contain 4 elements in 2 groups", func() {
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 4))
				Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 4))
				Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(ConsistOf("group1"))
				Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2"))
				Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group1", "group2"))
			})
			It("Should replace group content on refresh", func() {
				// refreshing group1 replaces its old entries entirely
				factory := cache.Refresh("group1")
				factory.AddEntry("newString")
				factory.Finish()
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2))
				Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 4))
				Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(BeEmpty())
				Expect(cache.Contains("newString", []string{"group1", "group2"})).Should(ConsistOf("group1"))
				Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2"))
				Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group2"))
			})
		})
	})
})

View File

@ -0,0 +1,25 @@
package stringcache
// GroupedStringCache is a string cache whose entries are partitioned into
// named groups that can be queried and refreshed independently.
type GroupedStringCache interface {
	// Contains checks if one or more groups in the cache contain the search string.
	// Returns the group(s) containing the string, or an empty slice if the string was not found.
	Contains(searchString string, groups []string) []string

	// Refresh creates a new factory for the group to be refreshed.
	// Calling Finish on the factory will perform the group refresh.
	Refresh(group string) GroupFactory

	// ElementCount returns the number of elements in the group.
	ElementCount(group string) int
}
// GroupFactory collects entries for a single group and replaces that
// group's content in the cache when Finish is called.
type GroupFactory interface {
	// AddEntry adds a new string to the factory to be added later to the cache groups.
	AddEntry(entry string)

	// Count returns the number of strings processed by the factory.
	Count() int

	// Finish replaces the group in the cache with the factory's content.
	Finish()
}

View File

@ -0,0 +1,82 @@
package stringcache
import "sync"
// stringCacheFactoryFn creates a new cacheFactory; it selects which
// concrete stringCache implementation a cache builds on Refresh.
type stringCacheFactoryFn func() cacheFactory

// InMemoryGroupedCache stores one stringCache per group.
// The caches map is guarded by lock; a group's stringCache is replaced
// wholesale when a refresh finishes.
type InMemoryGroupedCache struct {
	caches    map[string]stringCache
	lock      sync.RWMutex
	factoryFn stringCacheFactoryFn
}
// NewInMemoryGroupedStringCache creates a grouped cache backed by the
// plain string cache factory.
func NewInMemoryGroupedStringCache() *InMemoryGroupedCache {
	cache := InMemoryGroupedCache{
		caches:    map[string]stringCache{},
		factoryFn: newStringCacheFactory,
	}

	return &cache
}
// NewInMemoryGroupedRegexCache creates a grouped cache backed by the
// regex cache factory.
func NewInMemoryGroupedRegexCache() *InMemoryGroupedCache {
	cache := InMemoryGroupedCache{
		caches:    map[string]stringCache{},
		factoryFn: newRegexCacheFactory,
	}

	return &cache
}
// ElementCount returns the number of elements in the group, or 0 for an
// unknown group.
func (c *InMemoryGroupedCache) ElementCount(group string) int {
	c.lock.RLock()
	cache, ok := c.caches[group]
	c.lock.RUnlock()

	if ok {
		return cache.elementCount()
	}

	return 0
}
// Contains returns the subset of groups whose cache contains the search
// string.
func (c *InMemoryGroupedCache) Contains(searchString string, groups []string) []string {
	var matching []string

	for _, group := range groups {
		// only the map lookup happens under the read lock; the contains
		// check itself runs unlocked so it does not block refreshes
		c.lock.RLock()
		cache, ok := c.caches[group]
		c.lock.RUnlock()

		if ok && cache.contains(searchString) {
			matching = append(matching, group)
		}
	}

	return matching
}
// Refresh returns a factory that accumulates entries for the group and,
// on Finish, installs the freshly built cache under the write lock.
func (c *InMemoryGroupedCache) Refresh(group string) GroupFactory {
	store := func(sc stringCache) {
		c.lock.Lock()
		defer c.lock.Unlock()

		c.caches[group] = sc
	}

	return &inMemoryGroupFactory{
		factory:  c.factoryFn(),
		finishFn: store,
	}
}
// inMemoryGroupFactory wraps a cacheFactory and a callback that installs
// the finished stringCache into the owning grouped cache.
type inMemoryGroupFactory struct {
	factory  cacheFactory
	finishFn func(stringCache)
}
// AddEntry forwards the entry to the wrapped cache factory.
func (c *inMemoryGroupFactory) AddEntry(entry string) {
	c.factory.addEntry(entry)
}
// Count reports how many entries the wrapped factory has processed.
func (c *inMemoryGroupFactory) Count() int {
	return c.factory.count()
}
// Finish builds the final stringCache and hands it to the finish
// callback, which installs it in the owning grouped cache.
func (c *inMemoryGroupFactory) Finish() {
	c.finishFn(c.factory.create())
}

View File

@ -0,0 +1,132 @@
package stringcache_test
import (
"github.com/0xERR0R/blocky/cache/stringcache"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Tests for InMemoryGroupedCache: empty-cache behavior, factory-based
// group creation for both string and regex variants, and full group
// replacement on refresh.
//
// NOTE(review): the cache/factory setup statements run at Ginkgo
// tree-construction time, so the It blocks below observe state in the
// order they appear — reordering them would change the outcome.
var _ = Describe("In-Memory grouped cache", func() {
	Describe("Empty cache", func() {
		When("empty cache was created", func() {
			cache := stringcache.NewInMemoryGroupedStringCache()
			It("should have element count of 0", func() {
				Expect(cache.ElementCount("someGroup")).Should(BeNumerically("==", 0))
			})
			It("should not find any string", func() {
				Expect(cache.Contains("searchString", []string{"someGroup"})).Should(BeEmpty())
			})
		})
		When("cache with one empty group", func() {
			cache := stringcache.NewInMemoryGroupedStringCache()
			// Finish without any AddEntry installs an empty group
			factory := cache.Refresh("group1")
			factory.Finish()
			It("should have element count of 0", func() {
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 0))
			})
			It("should not find any string", func() {
				Expect(cache.Contains("searchString", []string{"group1"})).Should(BeEmpty())
			})
		})
	})
	Describe("Cache creation", func() {
		When("cache with 1 group was created", func() {
			cache := stringcache.NewInMemoryGroupedStringCache()
			factory := cache.Refresh("group1")
			factory.AddEntry("string1")
			factory.AddEntry("string2")
			It("cache should still have 0 element, since finish was not executed", func() {
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 0))
			})
			It("factory has 2 elements", func() {
				Expect(factory.Count()).Should(BeNumerically("==", 2))
			})
			It("should have element count of 2", func() {
				factory.Finish()
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2))
			})
			It("should find strings", func() {
				Expect(cache.Contains("string1", []string{"group1"})).Should(ConsistOf("group1"))
				Expect(cache.Contains("string2", []string{"group1", "someOtherGroup"})).Should(ConsistOf("group1"))
			})
		})
		When("String grouped cache is used", func() {
			cache := stringcache.NewInMemoryGroupedStringCache()
			factory := cache.Refresh("group1")
			// "/string2/" is a regex entry and must be skipped by the
			// plain string cache
			factory.AddEntry("string1")
			factory.AddEntry("/string2/")
			factory.Finish()
			It("should ignore regex", func() {
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1))
				Expect(cache.Contains("string1", []string{"group1"})).Should(ConsistOf("group1"))
			})
		})
		When("Regex grouped cache is used", func() {
			cache := stringcache.NewInMemoryGroupedRegexCache()
			factory := cache.Refresh("group1")
			// "string1" is not wrapped in slashes and must be skipped by
			// the regex cache
			factory.AddEntry("string1")
			factory.AddEntry("/string2/")
			factory.Finish()
			It("should ignore non-regex", func() {
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1))
				Expect(cache.Contains("string1", []string{"group1"})).Should(BeEmpty())
				Expect(cache.Contains("string2", []string{"group1"})).Should(ConsistOf("group1"))
				Expect(cache.Contains("shouldalsomatchstring2", []string{"group1"})).Should(ConsistOf("group1"))
			})
		})
	})
	Describe("Cache refresh", func() {
		When("cache with 2 groups was created", func() {
			cache := stringcache.NewInMemoryGroupedStringCache()
			factory := cache.Refresh("group1")
			factory.AddEntry("g1")
			factory.AddEntry("both")
			factory.Finish()
			factory = cache.Refresh("group2")
			factory.AddEntry("g2")
			factory.AddEntry("both")
			factory.Finish()
			It("should contain 4 elements in 2 groups", func() {
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2))
				Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2))
				Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(ConsistOf("group1"))
				Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2"))
				Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group1", "group2"))
			})
			It("Should replace group content on refresh", func() {
				// refreshing group1 drops its old entries entirely
				factory := cache.Refresh("group1")
				factory.AddEntry("newString")
				factory.Finish()
				Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1))
				Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2))
				Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(BeEmpty())
				Expect(cache.Contains("newString", []string{"group1", "group2"})).Should(ConsistOf("group1"))
				Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2"))
				Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group2"))
			})
		})
	})
})

View File

@ -8,23 +8,24 @@ import (
"github.com/0xERR0R/blocky/log"
)
type StringCache interface {
ElementCount() int
Contains(searchString string) bool
type stringCache interface {
elementCount() int
contains(searchString string) bool
}
type CacheFactory interface {
AddEntry(entry string)
Create() StringCache
type cacheFactory interface {
addEntry(entry string)
create() stringCache
count() int
}
type stringCache map[int]string
type stringMap map[int]string
func normalizeEntry(entry string) string {
return strings.ToLower(entry)
}
func (cache stringCache) ElementCount() int {
func (cache stringMap) elementCount() int {
count := 0
for k, v := range cache {
@ -34,7 +35,7 @@ func (cache stringCache) ElementCount() int {
return count
}
func (cache stringCache) Contains(searchString string) bool {
func (cache stringMap) contains(searchString string) bool {
normalized := normalizeEntry(searchString)
searchLen := len(normalized)
@ -57,9 +58,10 @@ func (cache stringCache) Contains(searchString string) bool {
type stringCacheFactory struct {
// temporary map which holds sorted slice of strings grouped by string length
tmp map[int][]string
cnt int
}
func newStringCacheFactory() CacheFactory {
func newStringCacheFactory() cacheFactory {
return &stringCacheFactory{
tmp: make(map[int][]string),
}
@ -73,6 +75,10 @@ func (s *stringCacheFactory) getBucket(length int) []string {
return s.tmp[length]
}
func (s *stringCacheFactory) count() int {
return s.cnt
}
func (s *stringCacheFactory) insertString(entry string) {
normalized := normalizeEntry(entry)
entryLen := len(normalized)
@ -92,15 +98,16 @@ func (s *stringCacheFactory) insertString(entry string) {
}
}
func (s *stringCacheFactory) AddEntry(entry string) {
// skip empty strings
if len(entry) > 0 {
func (s *stringCacheFactory) addEntry(entry string) {
// skip empty strings and regex
if len(entry) > 0 && !isRegex(entry) {
s.cnt++
s.insertString(entry)
}
}
func (s *stringCacheFactory) Create() StringCache {
cache := make(stringCache, len(s.tmp))
func (s *stringCacheFactory) create() stringCache {
cache := make(stringMap, len(s.tmp))
for k, v := range s.tmp {
cache[k] = strings.Join(v, "")
}
@ -110,13 +117,17 @@ func (s *stringCacheFactory) Create() StringCache {
return cache
}
func isRegex(s string) bool {
return strings.HasPrefix(s, "/") && strings.HasSuffix(s, "/")
}
type regexCache []*regexp.Regexp
func (cache regexCache) ElementCount() int {
func (cache regexCache) elementCount() int {
return len(cache)
}
func (cache regexCache) Contains(searchString string) bool {
func (cache regexCache) contains(searchString string) bool {
for _, regex := range cache {
if regex.MatchString(searchString) {
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
@ -132,71 +143,29 @@ type regexCacheFactory struct {
cache regexCache
}
func (r *regexCacheFactory) AddEntry(entry string) {
compile, err := regexp.Compile(entry)
if err != nil {
log.Log().Warnf("invalid regex '%s'", entry)
} else {
r.cache = append(r.cache, compile)
func (r *regexCacheFactory) addEntry(entry string) {
if isRegex(entry) {
entry = strings.TrimSpace(entry[1 : len(entry)-1])
compile, err := regexp.Compile(entry)
if err != nil {
log.Log().Warnf("invalid regex '%s'", entry)
} else {
r.cache = append(r.cache, compile)
}
}
}
func (r *regexCacheFactory) Create() StringCache {
func (r *regexCacheFactory) count() int {
return len(r.cache)
}
func (r *regexCacheFactory) create() stringCache {
return r.cache
}
func newRegexCacheFactory() CacheFactory {
func newRegexCacheFactory() cacheFactory {
return &regexCacheFactory{
cache: make(regexCache, 0),
}
}
type chainedCache struct {
caches []StringCache
}
func (cache chainedCache) ElementCount() int {
sum := 0
for _, c := range cache.caches {
sum += c.ElementCount()
}
return sum
}
func (cache chainedCache) Contains(searchString string) bool {
for _, c := range cache.caches {
if c.Contains(searchString) {
return true
}
}
return false
}
type chainedCacheFactory struct {
stringCacheFactory CacheFactory
regexCacheFactory CacheFactory
}
func (r *chainedCacheFactory) AddEntry(entry string) {
if strings.HasPrefix(entry, "/") && strings.HasSuffix(entry, "/") {
entry = strings.TrimSpace(entry[1 : len(entry)-1])
r.regexCacheFactory.AddEntry(entry)
} else {
r.stringCacheFactory.AddEntry(entry)
}
}
func (r *chainedCacheFactory) Create() StringCache {
return &chainedCache{
caches: []StringCache{r.stringCacheFactory.Create(), r.regexCacheFactory.Create()},
}
}
func NewChainedCacheFactory() CacheFactory {
return &chainedCacheFactory{
stringCacheFactory: newStringCacheFactory(),
regexCacheFactory: newRegexCacheFactory(),
}
}

View File

@ -14,10 +14,10 @@ func BenchmarkStringCache(b *testing.B) {
factory := newStringCacheFactory()
for _, s := range testdata {
factory.AddEntry(s)
factory.addEntry(s)
}
factory.Create()
factory.create()
}
}

View File

@ -9,28 +9,28 @@ var _ = Describe("Caches", func() {
Describe("String StringCache", func() {
When("string StringCache was created", func() {
factory := newStringCacheFactory()
factory.AddEntry("google.com")
factory.AddEntry("apple.com")
factory.AddEntry("")
factory.AddEntry("google.com")
factory.AddEntry("APPLe.com")
factory.addEntry("google.com")
factory.addEntry("apple.com")
factory.addEntry("")
factory.addEntry("google.com")
factory.addEntry("APPLe.com")
cache := factory.Create()
cache := factory.create()
It("should match if StringCache contains exact string", func() {
Expect(cache.Contains("apple.com")).Should(BeTrue())
Expect(cache.Contains("google.com")).Should(BeTrue())
Expect(cache.Contains("www.google.com")).Should(BeFalse())
Expect(cache.Contains("")).Should(BeFalse())
Expect(cache.contains("apple.com")).Should(BeTrue())
Expect(cache.contains("google.com")).Should(BeTrue())
Expect(cache.contains("www.google.com")).Should(BeFalse())
Expect(cache.contains("")).Should(BeFalse())
})
It("should match case-insensitive", func() {
Expect(cache.Contains("aPPle.com")).Should(BeTrue())
Expect(cache.Contains("google.COM")).Should(BeTrue())
Expect(cache.Contains("www.google.com")).Should(BeFalse())
Expect(cache.Contains("")).Should(BeFalse())
Expect(cache.contains("aPPle.com")).Should(BeTrue())
Expect(cache.contains("google.COM")).Should(BeTrue())
Expect(cache.contains("www.google.com")).Should(BeFalse())
Expect(cache.contains("")).Should(BeFalse())
})
It("should return correct element count", func() {
Expect(cache.ElementCount()).Should(Equal(2))
Expect(cache.elementCount()).Should(Equal(2))
})
})
})
@ -38,52 +38,30 @@ var _ = Describe("Caches", func() {
Describe("Regex StringCache", func() {
When("regex StringCache was created", func() {
factory := newRegexCacheFactory()
factory.AddEntry(".*google.com")
factory.AddEntry("^apple\\.(de|com)$")
factory.AddEntry("amazon")
factory.addEntry("/.*google.com/")
factory.addEntry("/^apple\\.(de|com)$/")
factory.addEntry("/amazon/")
// this is not a regex, will be ignored
factory.AddEntry("(wrongRegex")
cache := factory.Create()
factory.addEntry("/(wrongRegex/")
factory.addEntry("plaintext")
cache := factory.create()
It("should match if one regex in StringCache matches string", func() {
Expect(cache.Contains("google.com")).Should(BeTrue())
Expect(cache.Contains("google.coma")).Should(BeTrue())
Expect(cache.Contains("agoogle.com")).Should(BeTrue())
Expect(cache.Contains("www.google.com")).Should(BeTrue())
Expect(cache.Contains("apple.com")).Should(BeTrue())
Expect(cache.Contains("apple.de")).Should(BeTrue())
Expect(cache.Contains("apple.it")).Should(BeFalse())
Expect(cache.Contains("www.apple.com")).Should(BeFalse())
Expect(cache.Contains("applecom")).Should(BeFalse())
Expect(cache.Contains("www.amazon.com")).Should(BeTrue())
Expect(cache.Contains("amazon.com")).Should(BeTrue())
Expect(cache.Contains("myamazon.com")).Should(BeTrue())
Expect(cache.contains("google.com")).Should(BeTrue())
Expect(cache.contains("google.coma")).Should(BeTrue())
Expect(cache.contains("agoogle.com")).Should(BeTrue())
Expect(cache.contains("www.google.com")).Should(BeTrue())
Expect(cache.contains("apple.com")).Should(BeTrue())
Expect(cache.contains("apple.de")).Should(BeTrue())
Expect(cache.contains("apple.it")).Should(BeFalse())
Expect(cache.contains("www.apple.com")).Should(BeFalse())
Expect(cache.contains("applecom")).Should(BeFalse())
Expect(cache.contains("www.amazon.com")).Should(BeTrue())
Expect(cache.contains("amazon.com")).Should(BeTrue())
Expect(cache.contains("myamazon.com")).Should(BeTrue())
})
It("should return correct element count", func() {
Expect(cache.ElementCount()).Should(Equal(3))
})
})
})
Describe("Chained StringCache", func() {
When("chained StringCache was created", func() {
factory := NewChainedCacheFactory()
factory.AddEntry("/.*google.com/")
factory.AddEntry("/^apple\\.(de|com)$/")
factory.AddEntry("amazon.com")
cache := factory.Create()
It("should match if one regex in StringCache matches string", func() {
Expect(cache.Contains("google.com")).Should(BeTrue())
Expect(cache.Contains("google.coma")).Should(BeTrue())
Expect(cache.Contains("agoogle.com")).Should(BeTrue())
Expect(cache.Contains("www.google.com")).Should(BeTrue())
Expect(cache.Contains("apple.com")).Should(BeTrue())
Expect(cache.Contains("amazon.com")).Should(BeTrue())
Expect(cache.Contains("apple.de")).Should(BeTrue())
Expect(cache.Contains("www.apple.com")).Should(BeFalse())
Expect(cache.Contains("applecom")).Should(BeFalse())
})
It("should return correct element count", func() {
Expect(cache.ElementCount()).Should(Equal(3))
Expect(factory.count()).Should(Equal(3))
Expect(cache.elementCount()).Should(Equal(3))
})
})
})

View File

@ -9,7 +9,6 @@ import (
"net"
"os"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
@ -36,13 +35,12 @@ type ListCacheType int
// Matcher checks if a domain is in a list
type Matcher interface {
// Match matches passed domain name against cached list entries
Match(domain string, groupsToCheck []string) (found bool, group string)
Match(domain string, groupsToCheck []string) (groups []string)
}
// ListCache generic cache of strings divided in groups
type ListCache struct {
groupCaches map[string]stringcache.StringCache
lock sync.RWMutex
groupedCache stringcache.GroupedStringCache
groupToLinks map[string][]string
refreshPeriod time.Duration
@ -55,12 +53,10 @@ type ListCache struct {
func (b *ListCache) LogConfig(logger *logrus.Entry) {
var total int
b.lock.RLock()
defer b.lock.RUnlock()
for group, cache := range b.groupCaches {
logger.Infof("%s: %d entries", group, cache.ElementCount())
total += cache.ElementCount()
for group := range b.groupToLinks {
count := b.groupedCache.ElementCount(group)
logger.Infof("%s: %d entries", group, count)
total += count
}
logger.Infof("TOTAL: %d entries", total)
@ -70,15 +66,16 @@ func (b *ListCache) LogConfig(logger *logrus.Entry) {
func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeriod time.Duration,
downloader FileDownloader, processingConcurrency uint, async bool,
) (*ListCache, error) {
groupCaches := make(map[string]stringcache.StringCache)
if processingConcurrency == 0 {
processingConcurrency = defaultProcessingConcurrency
}
b := &ListCache{
groupedCache: stringcache.NewChainedGroupedCache(
stringcache.NewInMemoryGroupedStringCache(),
stringcache.NewInMemoryGroupedRegexCache(),
),
groupToLinks: groupToLinks,
groupCaches: groupCaches,
refreshPeriod: refreshPeriod,
downloader: downloader,
listType: t,
@ -122,10 +119,8 @@ func logger() *logrus.Entry {
// downloads and reads files with domain names and creates cache for them
//
//nolint:funlen // will refactor in a later commit
func (b *ListCache) createCacheForGroup(links []string) (stringcache.StringCache, error) {
var err error
factory := stringcache.NewChainedCacheFactory()
func (b *ListCache) createCacheForGroup(group string, links []string) (created bool, err error) {
groupFactory := b.groupedCache.Refresh(group)
fileLinesChan := make(chan string, chanCap)
errChan := make(chan error, chanCap)
@ -166,12 +161,12 @@ Loop:
for {
select {
case line := <-fileLinesChan:
factory.AddEntry(line)
groupFactory.AddEntry(line)
case e := <-errChan:
var transientErr *TransientError
if errors.As(e, &transientErr) {
return nil, e
return false, e
}
err = multierror.Append(err, e)
case <-workerDoneChan:
@ -184,26 +179,18 @@ Loop:
}
}
cache := factory.Create()
if cache.ElementCount() == 0 && err != nil {
cache = nil // don't replace existing cache
if groupFactory.Count() == 0 && err != nil {
return false, err
}
return cache, err
groupFactory.Finish()
return true, err
}
// Match matches passed domain name against cached list entries
func (b *ListCache) Match(domain string, groupsToCheck []string) (found bool, group string) {
b.lock.RLock()
defer b.lock.RUnlock()
for _, g := range groupsToCheck {
if c, ok := b.groupCaches[g]; ok && c.Contains(domain) {
return true, g
}
}
return false, ""
func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []string) {
return b.groupedCache.Contains(domain, groupsToCheck)
}
// Refresh triggers the refresh of a list
@ -215,20 +202,20 @@ func (b *ListCache) refresh(isInit bool) error {
var err error
for group, links := range b.groupToLinks {
cacheForGroup, e := b.createCacheForGroup(links)
created, e := b.createCacheForGroup(group, links)
if e != nil {
err = multierror.Append(err, multierror.Prefix(e, fmt.Sprintf("can't create cache group '%s':", group)))
}
if cacheForGroup == nil {
count := b.groupElementCount(group, isInit)
count := b.groupedCache.ElementCount(group)
if !created {
logger := logger().WithFields(logrus.Fields{
"group": group,
"total_count": count,
})
if count == 0 {
if count == 0 || isInit {
logger.Warn("Populating of group cache failed, cache will be empty until refresh succeeds")
} else {
logger.Warn("Populating of group cache failed, using existing cache, if any")
@ -237,37 +224,17 @@ func (b *ListCache) refresh(isInit bool) error {
continue
}
b.lock.Lock()
b.groupCaches[group] = cacheForGroup
b.lock.Unlock()
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, cacheForGroup.ElementCount())
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, count)
logger().WithFields(logrus.Fields{
"group": group,
"total_count": cacheForGroup.ElementCount(),
"total_count": count,
}).Info("group import finished")
}
return err
}
func (b *ListCache) groupElementCount(group string, isInit bool) int {
if isInit {
return 0
}
b.lock.RLock()
oldCache, ok := b.groupCaches[group]
b.lock.RUnlock()
if !ok {
return 0
}
return oldCache.ElementCount()
}
func readFile(file string) (io.ReadCloser, error) {
logger().WithField("file", file).Info("starting processing of file")
file = strings.TrimPrefix(file, "file://")

View File

@ -59,8 +59,7 @@ var _ = Describe("ListCache", func() {
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("", []string{"gr0"})
Expect(found).Should(BeFalse())
group := sut.Match("", []string{"gr0"})
Expect(group).Should(BeEmpty())
})
})
@ -73,8 +72,7 @@ var _ = Describe("ListCache", func() {
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("google.com", []string{"gr1"})
Expect(found).Should(BeFalse())
group := sut.Match("google.com", []string{"gr1"})
Expect(group).Should(BeEmpty())
})
})
@ -98,15 +96,13 @@ var _ = Describe("ListCache", func() {
)
Expect(err).Should(Succeed())
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
err = sut.refresh(false)
Expect(err).Should(Succeed())
found, group = sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeFalse())
group = sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(BeEmpty())
})
})
@ -126,13 +122,11 @@ var _ = Describe("ListCache", func() {
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("inlinedomain2.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group = sut.Match("inlinedomain2.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
When("a temporary/transient err occurs on download", func() {
@ -159,26 +153,23 @@ var _ = Describe("ListCache", func() {
By("Lists loaded without timeout", func() {
Eventually(func(g Gomega) {
found, group := sut.Match("blocked1.com", []string{"gr1"})
g.Expect(found).Should(BeTrue())
g.Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1"})
g.Expect(group).Should(ContainElement("gr1"))
}, "1s").Should(Succeed())
})
Expect(sut.refresh(false)).Should(HaveOccurred())
By("List couldn't be loaded due to timeout", func() {
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
sut.Refresh()
By("List couldn't be loaded due to timeout", func() {
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
})
@ -199,17 +190,15 @@ var _ = Describe("ListCache", func() {
Expect(err).Should(Succeed())
By("Lists loaded without err", func() {
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
Expect(sut.refresh(false)).Should(HaveOccurred())
By("Lists from first load is kept", func() {
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
})
@ -222,17 +211,14 @@ var _ = Describe("ListCache", func() {
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false)
found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("blocked1a.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group = sut.Match("blocked1a.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("blocked1a.com", []string{"gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr2"))
group = sut.Match("blocked1a.com", []string{"gr2"})
Expect(group).Should(ContainElement("gr2"))
})
})
When("Configuration has some faulty urls", func() {
@ -244,17 +230,14 @@ var _ = Describe("ListCache", func() {
sut, _ := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false)
found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("blocked1a.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group = sut.Match("blocked1a.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElements("gr1", "gr2"))
found, group = sut.Match("blocked1a.com", []string{"gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr2"))
group = sut.Match("blocked1a.com", []string{"gr2"})
Expect(group).Should(ContainElement("gr2"))
})
})
When("List will be updated", func() {
@ -272,8 +255,7 @@ var _ = Describe("ListCache", func() {
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("blocked1.com", []string{})
Expect(found).Should(BeFalse())
group := sut.Match("blocked1.com", []string{})
Expect(group).Should(BeEmpty())
Expect(resultCnt).Should(Equal(3))
})
@ -288,20 +270,17 @@ var _ = Describe("ListCache", func() {
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(), defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
Expect(sut.groupCaches["gr1"].ElementCount()).Should(Equal(3))
Expect(sut.groupCaches["gr2"].ElementCount()).Should(Equal(2))
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(3))
Expect(sut.groupedCache.ElementCount("gr2")).Should(Equal(2))
found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("blocked1.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("blocked1a.com", []string{"gr1", "gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group = sut.Match("blocked1a.com", []string{"gr1", "gr2"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("blocked1a.com", []string{"gr2"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr2"))
group = sut.Match("blocked1a.com", []string{"gr2"})
Expect(group).Should(ContainElement("gr2"))
})
})
When("group with bigger files", func() {
@ -317,7 +296,7 @@ var _ = Describe("ListCache", func() {
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
Expect(sut.groupCaches["gr1"].ElementCount()).Should(Equal(lines1 + lines2 + lines3))
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(lines1 + lines2 + lines3))
})
})
When("inline list content is defined", func() {
@ -334,14 +313,12 @@ var _ = Describe("ListCache", func() {
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
Expect(sut.groupCaches["gr1"].ElementCount()).Should(Equal(2))
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(2))
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("inlinedomain2.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group = sut.Match("inlinedomain2.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
When("Text file can't be parsed", func() {
@ -359,9 +336,8 @@ var _ = Describe("ListCache", func() {
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
When("Text file has too many errors", func() {
@ -390,9 +366,8 @@ var _ = Describe("ListCache", func() {
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
When("inline regex content is defined", func() {
@ -405,13 +380,11 @@ var _ = Describe("ListCache", func() {
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("apple.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group := sut.Match("apple.com", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
found, group = sut.Match("apple.de", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
group = sut.Match("apple.de", []string{"gr1"})
Expect(group).Should(ContainElement("gr1"))
})
})
})

View File

@ -360,8 +360,8 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
domain := util.ExtractDomain(question)
logger := logger.WithField("domain", domain)
if whitelisted, group := r.matches(groupsToCheck, r.whitelistMatcher, domain); whitelisted {
logger.WithField("group", group).Debugf("domain is whitelisted")
if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 {
logger.WithField("groups", groups).Debugf("domain is whitelisted")
resp, err := r.next.Resolve(request)
@ -374,8 +374,8 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
return true, resp, err
}
if blocked, group := r.matches(groupsToCheck, r.blacklistMatcher, domain); blocked {
resp, err := r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", group))
if groups := r.matches(groupsToCheck, r.blacklistMatcher, domain); len(groups) > 0 {
resp, err := r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", strings.Join(groups, ",")))
return true, resp, err
}
@ -404,10 +404,11 @@ func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, err
if len(entryToCheck) > 0 {
logger := logger.WithField("response_entry", entryToCheck)
if whitelisted, group := r.matches(groupsToCheck, r.whitelistMatcher, entryToCheck); whitelisted {
logger.WithField("group", group).Debugf("%s is whitelisted", tName)
} else if blocked, group := r.matches(groupsToCheck, r.blacklistMatcher, entryToCheck); blocked {
return r.handleBlocked(logger, request, request.Req.Question[0], fmt.Sprintf("BLOCKED %s (%s)", tName, group))
if groups := r.matches(groupsToCheck, r.whitelistMatcher, entryToCheck); len(groups) > 0 {
logger.WithField("groups", groups).Debugf("%s is whitelisted", tName)
} else if groups := r.matches(groupsToCheck, r.blacklistMatcher, entryToCheck); len(groups) > 0 {
return r.handleBlocked(logger, request, request.Req.Question[0], fmt.Sprintf("BLOCKED %s (%s)", tName,
strings.Join(groups, ",")))
}
}
}
@ -503,15 +504,12 @@ func (r *BlockingResolver) groupsToCheckForClient(request *model.Request) []stri
func (r *BlockingResolver) matches(groupsToCheck []string, m lists.Matcher,
domain string,
) (blocked bool, group string) {
) (group []string) {
if len(groupsToCheck) > 0 {
found, group := m.Match(domain, groupsToCheck)
if found {
return true, group
}
return m.Match(domain, groupsToCheck)
}
return false, ""
return []string{}
}
type blockHandler interface {