mirror of https://github.com/0xERR0R/blocky.git
feat(lists): add support for wildcard lists using a custom Trie (#1233)
This commit is contained in:
parent
f1a6fb0014
commit
b498bc5094
|
@ -12,5 +12,6 @@ todo.txt
|
|||
node_modules
|
||||
package-lock.json
|
||||
vendor/
|
||||
coverage.html
|
||||
coverage.txt
|
||||
coverage/
|
||||
|
|
5
Makefile
5
Makefile
|
@ -45,7 +45,7 @@ else
|
|||
go generate ./...
|
||||
endif
|
||||
|
||||
build: fmt generate ## Build binary
|
||||
build: generate ## Build binary
|
||||
go build $(GO_BUILD_FLAGS) -ldflags="$(GO_BUILD_LD_FLAGS)" -o $(GO_BUILD_OUTPUT)
|
||||
ifdef BIN_USER
|
||||
$(info setting owner of $(GO_BUILD_OUTPUT) to $(BIN_USER))
|
||||
|
@ -58,6 +58,7 @@ endif
|
|||
|
||||
test: ## run tests
|
||||
go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!e2e" --coverprofile=coverage.txt --covermode=atomic --cover -r ${GINKGO_PROCS}
|
||||
go tool cover -html coverage.txt -o coverage.html
|
||||
|
||||
e2e-test: ## run e2e tests
|
||||
docker buildx build \
|
||||
|
@ -81,7 +82,7 @@ fmt: ## gofmt and goimports all go files
|
|||
go run mvdan.cc/gofumpt -l -w -extra .
|
||||
find . -name '*.go' -exec go run golang.org/x/tools/cmd/goimports -w {} +
|
||||
|
||||
docker-build: generate ## Build docker image
|
||||
docker-build: generate ## Build docker image
|
||||
docker buildx build \
|
||||
--build-arg VERSION=${VERSION} \
|
||||
--build-arg BUILD_TIME=${BUILD_TIME} \
|
||||
|
|
|
@ -56,10 +56,14 @@ type chainedGroupFactory struct {
|
|||
cacheFactories []GroupFactory
|
||||
}
|
||||
|
||||
func (c *chainedGroupFactory) AddEntry(entry string) {
|
||||
func (c *chainedGroupFactory) AddEntry(entry string) bool {
|
||||
for _, factory := range c.cacheFactories {
|
||||
factory.AddEntry(entry)
|
||||
if factory.AddEntry(entry) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *chainedGroupFactory) Count() int {
|
||||
|
|
|
@ -24,6 +24,10 @@ var _ = Describe("Chained grouped cache", func() {
|
|||
It("should not find any string", func() {
|
||||
Expect(cache.Contains("searchString", []string{"someGroup"})).Should(BeEmpty())
|
||||
})
|
||||
|
||||
It("should not add entries", func() {
|
||||
Expect(cache.Refresh("group").AddEntry("test")).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
Describe("Delegation", func() {
|
||||
|
@ -44,12 +48,12 @@ var _ = Describe("Chained grouped cache", func() {
|
|||
})
|
||||
|
||||
It("factory has 4 elements (both caches)", func() {
|
||||
Expect(factory.Count()).Should(BeNumerically("==", 4))
|
||||
Expect(factory.Count()).Should(BeNumerically("==", 2))
|
||||
})
|
||||
|
||||
It("should have element count of 4", func() {
|
||||
factory.Finish()
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 4))
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2))
|
||||
})
|
||||
|
||||
It("should find strings", func() {
|
||||
|
@ -80,25 +84,38 @@ var _ = Describe("Chained grouped cache", func() {
|
|||
})
|
||||
|
||||
It("should contain 4 elements in 2 groups", func() {
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 4))
|
||||
Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 4))
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2))
|
||||
Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2))
|
||||
Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(ConsistOf("group1"))
|
||||
Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2"))
|
||||
Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group1", "group2"))
|
||||
})
|
||||
|
||||
It("Should replace group content on refresh", func() {
|
||||
It("should replace group content on refresh", func() {
|
||||
factory = cache.Refresh("group1")
|
||||
factory.AddEntry("newString")
|
||||
factory.Finish()
|
||||
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2))
|
||||
Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 4))
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1))
|
||||
Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2))
|
||||
Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(BeEmpty())
|
||||
Expect(cache.Contains("newString", []string{"group1", "group2"})).Should(ConsistOf("group1"))
|
||||
Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2"))
|
||||
Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group2"))
|
||||
})
|
||||
|
||||
It("should replace empty groups on refresh", func() {
|
||||
factory = cache.Refresh("group")
|
||||
factory.AddEntry("begone")
|
||||
factory.Finish()
|
||||
|
||||
Expect(cache.ElementCount("group")).Should(BeNumerically("==", 1))
|
||||
|
||||
factory = cache.Refresh("group")
|
||||
factory.Finish()
|
||||
|
||||
Expect(cache.ElementCount("group")).Should(BeNumerically("==", 0))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -15,7 +15,7 @@ type GroupedStringCache interface {
|
|||
|
||||
type GroupFactory interface {
|
||||
// AddEntry adds a new string to the factory to be added later to the cache groups.
|
||||
AddEntry(entry string)
|
||||
AddEntry(entry string) bool
|
||||
|
||||
// Count returns amount of processed string in the factory
|
||||
Count() int
|
||||
|
|
|
@ -24,6 +24,13 @@ func NewInMemoryGroupedRegexCache() *InMemoryGroupedCache {
|
|||
}
|
||||
}
|
||||
|
||||
func NewInMemoryGroupedWildcardCache() *InMemoryGroupedCache {
|
||||
return &InMemoryGroupedCache{
|
||||
caches: make(map[string]stringCache),
|
||||
factoryFn: newWildcardCacheFactory,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *InMemoryGroupedCache) ElementCount(group string) int {
|
||||
c.lock.RLock()
|
||||
cache, found := c.caches[group]
|
||||
|
@ -57,8 +64,13 @@ func (c *InMemoryGroupedCache) Refresh(group string) GroupFactory {
|
|||
factory: c.factoryFn(),
|
||||
finishFn: func(sc stringCache) {
|
||||
c.lock.Lock()
|
||||
c.caches[group] = sc
|
||||
c.lock.Unlock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if sc != nil {
|
||||
c.caches[group] = sc
|
||||
} else {
|
||||
delete(c.caches, group)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -68,8 +80,8 @@ type inMemoryGroupFactory struct {
|
|||
finishFn func(stringCache)
|
||||
}
|
||||
|
||||
func (c *inMemoryGroupFactory) AddEntry(entry string) {
|
||||
c.factory.addEntry(entry)
|
||||
func (c *inMemoryGroupFactory) AddEntry(entry string) bool {
|
||||
return c.factory.addEntry(entry)
|
||||
}
|
||||
|
||||
func (c *inMemoryGroupFactory) Count() int {
|
||||
|
|
|
@ -46,8 +46,8 @@ var _ = Describe("In-Memory grouped cache", func() {
|
|||
cache = stringcache.NewInMemoryGroupedStringCache()
|
||||
factory = cache.Refresh("group1")
|
||||
|
||||
factory.AddEntry("string1")
|
||||
factory.AddEntry("string2")
|
||||
Expect(factory.AddEntry("string1")).Should(BeTrue())
|
||||
Expect(factory.AddEntry("string2")).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("cache should still have 0 element, since finish was not executed", func() {
|
||||
|
@ -69,28 +69,13 @@ var _ = Describe("In-Memory grouped cache", func() {
|
|||
Expect(cache.Contains("string2", []string{"group1", "someOtherGroup"})).Should(ConsistOf("group1"))
|
||||
})
|
||||
})
|
||||
When("String grouped cache is used", func() {
|
||||
BeforeEach(func() {
|
||||
cache = stringcache.NewInMemoryGroupedStringCache()
|
||||
factory = cache.Refresh("group1")
|
||||
|
||||
factory.AddEntry("string1")
|
||||
factory.AddEntry("/string2/")
|
||||
factory.Finish()
|
||||
})
|
||||
|
||||
It("should ignore regex", func() {
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1))
|
||||
Expect(cache.Contains("string1", []string{"group1"})).Should(ConsistOf("group1"))
|
||||
})
|
||||
})
|
||||
When("Regex grouped cache is used", func() {
|
||||
BeforeEach(func() {
|
||||
cache = stringcache.NewInMemoryGroupedRegexCache()
|
||||
factory = cache.Refresh("group1")
|
||||
|
||||
factory.AddEntry("string1")
|
||||
factory.AddEntry("/string2/")
|
||||
Expect(factory.AddEntry("string1")).Should(BeFalse())
|
||||
Expect(factory.AddEntry("/string2/")).Should(BeTrue())
|
||||
factory.Finish()
|
||||
})
|
||||
|
||||
|
@ -101,6 +86,25 @@ var _ = Describe("In-Memory grouped cache", func() {
|
|||
Expect(cache.Contains("shouldalsomatchstring2", []string{"group1"})).Should(ConsistOf("group1"))
|
||||
})
|
||||
})
|
||||
When("Wildcard grouped cache is used", func() {
|
||||
BeforeEach(func() {
|
||||
cache = stringcache.NewInMemoryGroupedWildcardCache()
|
||||
factory = cache.Refresh("group1")
|
||||
|
||||
Expect(factory.AddEntry("string1")).Should(BeFalse())
|
||||
Expect(factory.AddEntry("/string2/")).Should(BeFalse())
|
||||
Expect(factory.AddEntry("*.string3")).Should(BeTrue())
|
||||
factory.Finish()
|
||||
})
|
||||
|
||||
It("should ignore non-wildcard", func() {
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1))
|
||||
Expect(cache.Contains("string1", []string{"group1"})).Should(BeEmpty())
|
||||
Expect(cache.Contains("string2", []string{"group1"})).Should(BeEmpty())
|
||||
Expect(cache.Contains("string3", []string{"group1"})).Should(ConsistOf("group1"))
|
||||
Expect(cache.Contains("shouldalsomatch.string3", []string{"group1"})).Should(ConsistOf("group1"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Cache refresh", func() {
|
||||
|
@ -109,13 +113,13 @@ var _ = Describe("In-Memory grouped cache", func() {
|
|||
cache = stringcache.NewInMemoryGroupedStringCache()
|
||||
factory = cache.Refresh("group1")
|
||||
|
||||
factory.AddEntry("g1")
|
||||
factory.AddEntry("both")
|
||||
Expect(factory.AddEntry("g1")).Should(BeTrue())
|
||||
Expect(factory.AddEntry("both")).Should(BeTrue())
|
||||
factory.Finish()
|
||||
|
||||
factory = cache.Refresh("group2")
|
||||
factory.AddEntry("g2")
|
||||
factory.AddEntry("both")
|
||||
Expect(factory.AddEntry("g2")).Should(BeTrue())
|
||||
Expect(factory.AddEntry("both")).Should(BeTrue())
|
||||
factory.Finish()
|
||||
})
|
||||
|
||||
|
@ -129,7 +133,7 @@ var _ = Describe("In-Memory grouped cache", func() {
|
|||
|
||||
It("Should replace group content on refresh", func() {
|
||||
factory = cache.Refresh("group1")
|
||||
factory.AddEntry("newString")
|
||||
Expect(factory.AddEntry("newString")).Should(BeTrue())
|
||||
factory.Finish()
|
||||
|
||||
Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1))
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/trie"
|
||||
)
|
||||
|
||||
type stringCache interface {
|
||||
|
@ -14,7 +15,7 @@ type stringCache interface {
|
|||
}
|
||||
|
||||
type cacheFactory interface {
|
||||
addEntry(entry string)
|
||||
addEntry(entry string) bool
|
||||
create() stringCache
|
||||
count() int
|
||||
}
|
||||
|
@ -86,7 +87,7 @@ func (s *stringCacheFactory) insertString(entry string) {
|
|||
ix := sort.SearchStrings(bucket, normalized)
|
||||
|
||||
if !(ix < len(bucket) && bucket[ix] == normalized) {
|
||||
// extent internal bucket
|
||||
// extend internal bucket
|
||||
bucket = append(s.getBucket(entryLen), "")
|
||||
|
||||
// move elements to make place for the insertion
|
||||
|
@ -98,29 +99,30 @@ func (s *stringCacheFactory) insertString(entry string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *stringCacheFactory) addEntry(entry string) {
|
||||
// skip empty strings and regex
|
||||
if len(entry) > 0 && !isRegex(entry) {
|
||||
s.cnt++
|
||||
s.insertString(entry)
|
||||
func (s *stringCacheFactory) addEntry(entry string) bool {
|
||||
if len(entry) == 0 {
|
||||
return true // invalid but handled
|
||||
}
|
||||
|
||||
s.cnt++
|
||||
s.insertString(entry)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *stringCacheFactory) create() stringCache {
|
||||
if len(s.tmp) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cache := make(stringMap, len(s.tmp))
|
||||
for k, v := range s.tmp {
|
||||
cache[k] = strings.Join(v, "")
|
||||
}
|
||||
|
||||
s.tmp = nil
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func isRegex(s string) bool {
|
||||
return strings.HasPrefix(s, "/") && strings.HasSuffix(s, "/")
|
||||
}
|
||||
|
||||
type regexCache []*regexp.Regexp
|
||||
|
||||
func (cache regexCache) elementCount() int {
|
||||
|
@ -143,17 +145,24 @@ type regexCacheFactory struct {
|
|||
cache regexCache
|
||||
}
|
||||
|
||||
func (r *regexCacheFactory) addEntry(entry string) {
|
||||
if isRegex(entry) {
|
||||
entry = strings.TrimSpace(entry[1 : len(entry)-1])
|
||||
compile, err := regexp.Compile(entry)
|
||||
|
||||
if err != nil {
|
||||
log.Log().Warnf("invalid regex '%s'", entry)
|
||||
} else {
|
||||
r.cache = append(r.cache, compile)
|
||||
}
|
||||
func (r *regexCacheFactory) addEntry(entry string) bool {
|
||||
if !strings.HasPrefix(entry, "/") || !strings.HasSuffix(entry, "/") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Trim slashes
|
||||
entry = strings.TrimSpace(entry[1 : len(entry)-1])
|
||||
|
||||
compile, err := regexp.Compile(entry)
|
||||
if err != nil {
|
||||
log.Log().Warnf("invalid regex '%s'", entry)
|
||||
|
||||
return true // invalid but handled
|
||||
}
|
||||
|
||||
r.cache = append(r.cache, compile)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *regexCacheFactory) count() int {
|
||||
|
@ -161,6 +170,10 @@ func (r *regexCacheFactory) count() int {
|
|||
}
|
||||
|
||||
func (r *regexCacheFactory) create() stringCache {
|
||||
if len(r.cache) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.cache
|
||||
}
|
||||
|
||||
|
@ -169,3 +182,68 @@ func newRegexCacheFactory() cacheFactory {
|
|||
cache: make(regexCache, 0),
|
||||
}
|
||||
}
|
||||
|
||||
type wildcardCache struct {
|
||||
trie trie.Trie
|
||||
cnt int
|
||||
}
|
||||
|
||||
func (cache wildcardCache) elementCount() int {
|
||||
return cache.cnt
|
||||
}
|
||||
|
||||
func (cache wildcardCache) contains(domain string) bool {
|
||||
return cache.trie.HasParentOf(domain)
|
||||
}
|
||||
|
||||
type wildcardCacheFactory struct {
|
||||
trie *trie.Trie
|
||||
cnt int
|
||||
}
|
||||
|
||||
func newWildcardCacheFactory() cacheFactory {
|
||||
return &wildcardCacheFactory{
|
||||
trie: trie.NewTrie(trie.SplitTLD),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *wildcardCacheFactory) addEntry(entry string) bool {
|
||||
globCount := strings.Count(entry, "*")
|
||||
if globCount == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(entry, "*.") || globCount > 1 {
|
||||
log.Log().Warnf("unsupported wildcard '%s': must start with '*.' and contain no other '*'", entry)
|
||||
|
||||
return true // invalid but handled
|
||||
}
|
||||
|
||||
entry = normalizeWildcard(entry)
|
||||
|
||||
r.trie.Insert(entry)
|
||||
r.cnt++
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *wildcardCacheFactory) count() int {
|
||||
return r.cnt
|
||||
}
|
||||
|
||||
func (r *wildcardCacheFactory) create() stringCache {
|
||||
if r.cnt == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return wildcardCache{*r.trie, r.cnt}
|
||||
}
|
||||
|
||||
func normalizeWildcard(domain string) string {
|
||||
domain = normalizeEntry(domain)
|
||||
domain = strings.TrimLeft(domain, "*")
|
||||
domain = strings.Trim(domain, ".")
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
return domain
|
||||
}
|
||||
|
|
|
@ -1,44 +1,248 @@
|
|||
package stringcache
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/0xERR0R/blocky/lists/parsers"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
)
|
||||
|
||||
func BenchmarkStringCache(b *testing.B) {
|
||||
testdata := createTestdata(10_000)
|
||||
var (
|
||||
// String and Wildcard benchmarks don't use the exact same data,
|
||||
// but since it's two versions of the same list it's closer to
|
||||
// the real world: we build the cache using different sources, but then check
|
||||
// the same list of domains.
|
||||
//
|
||||
// It is possible to run the benchmarks using the exact same data: set `useRealLists`
|
||||
// to `false`. The results should be similar to the current ones, with memory use
|
||||
// changing the most.
|
||||
useRealLists = true
|
||||
|
||||
regexTestData []string
|
||||
stringTestData []string
|
||||
wildcardTestData []string
|
||||
|
||||
baseMemStats runtime.MemStats
|
||||
)
|
||||
|
||||
func init() { //nolint:gochecknoinits
|
||||
// If you update either list, make sure both are the list version (see file header).
|
||||
stringTestData = loadTestdata("../../helpertest/data/oisd-big-plain.txt")
|
||||
|
||||
if useRealLists {
|
||||
wildcardTestData = loadTestdata("../../helpertest/data/oisd-big-wildcard.txt")
|
||||
|
||||
// Domain is in plain but not wildcard list, add it so `benchmarkCache` doesn't fail
|
||||
wildcardTestData = append(wildcardTestData, "*.btest.oisd.nl")
|
||||
} else {
|
||||
wildcardTestData = make([]string, 0, len(stringTestData))
|
||||
|
||||
for _, domain := range stringTestData {
|
||||
wildcardTestData = append(wildcardTestData, "*."+domain)
|
||||
}
|
||||
}
|
||||
|
||||
// OISD regex list is the exact same as the wildcard one, just using a different format
|
||||
regexTestData = make([]string, 0, len(wildcardTestData))
|
||||
|
||||
for _, wildcard := range wildcardTestData {
|
||||
domain := strings.TrimPrefix(wildcard, "*.")
|
||||
|
||||
// /^(.*\.)?subdomain\.example\.com$/
|
||||
regex := fmt.Sprintf(`/^(.*\.)?%s$/`, regexp.QuoteMeta(domain))
|
||||
|
||||
regexTestData = append(regexTestData, regex)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Cache Building ---
|
||||
//
|
||||
// Most memory efficient: Wildcard (blocky/trie radix) because of peak
|
||||
// Fastest: Wildcard (blocky/trie original)
|
||||
//
|
||||
//nolint:lll
|
||||
// BenchmarkRegexFactory-8 1 1 253 023 507 ns/op 430.60 fact_heap_MB 430.60 peak_heap_MB 1 792 669 024 B/op 9 826 986 allocs/op
|
||||
// BenchmarkStringFactory-8 7 163 969 933 ns/op 11.79 fact_heap_MB 26.91 peak_heap_MB 67 613 890 B/op 1 306 allocs/op
|
||||
// BenchmarkWildcardFactory-8 19 60 592 988 ns/op 16.60 fact_heap_MB 16.60 peak_heap_MB 26 740 317 B/op 92 245 allocs/op (original)
|
||||
// BenchmarkWildcardFactory-8 16 65 179 284 ns/op 14.92 fact_heap_MB 14.92 peak_heap_MB 27 997 734 B/op 52 937 allocs/op (radix)
|
||||
|
||||
func BenchmarkRegexFactory(b *testing.B) {
|
||||
benchmarkRegexFactory(b, newRegexCacheFactory)
|
||||
}
|
||||
|
||||
func BenchmarkStringFactory(b *testing.B) {
|
||||
benchmarkStringFactory(b, newStringCacheFactory)
|
||||
}
|
||||
|
||||
func BenchmarkWildcardFactory(b *testing.B) {
|
||||
benchmarkWildcardFactory(b, newWildcardCacheFactory)
|
||||
}
|
||||
|
||||
func benchmarkRegexFactory(b *testing.B, newFactory func() cacheFactory) {
|
||||
benchmarkFactory(b, regexTestData, newFactory)
|
||||
}
|
||||
|
||||
func benchmarkStringFactory(b *testing.B, newFactory func() cacheFactory) {
|
||||
benchmarkFactory(b, stringTestData, newFactory)
|
||||
}
|
||||
|
||||
func benchmarkWildcardFactory(b *testing.B, newFactory func() cacheFactory) {
|
||||
benchmarkFactory(b, wildcardTestData, newFactory)
|
||||
}
|
||||
|
||||
func benchmarkFactory(b *testing.B, data []string, newFactory func() cacheFactory) {
|
||||
baseMemStats = readMemStats()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
var (
|
||||
factory cacheFactory
|
||||
cache stringCache
|
||||
)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
factory := newStringCacheFactory()
|
||||
factory = newFactory()
|
||||
|
||||
for _, s := range testdata {
|
||||
factory.addEntry(s)
|
||||
for _, s := range data {
|
||||
if !factory.addEntry(s) {
|
||||
b.Fatalf("cache didn't insert value: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
factory.create()
|
||||
cache = factory.create()
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
reportMemUsage(b, "peak", factory, cache)
|
||||
reportMemUsage(b, "fact", factory) // cache will be GC'd
|
||||
}
|
||||
|
||||
// --- Cache Querying ---
|
||||
//
|
||||
// Most memory efficient: Wildcard (blocky/trie radix)
|
||||
// Fastest: Wildcard (blocky/trie original)
|
||||
//
|
||||
//nolint:lll
|
||||
// BenchmarkStringCache-8 6 204 754 798 ns/op 15.11 cache_heap_MB 0 B/op 0 allocs/op
|
||||
// BenchmarkWildcardCache-8 14 76 186 334 ns/op 16.61 cache_heap_MB 0 B/op 0 allocs/op (original)
|
||||
// BenchmarkWildcardCache-8 12 95 316 121 ns/op 14.91 cache_heap_MB 0 B/op 0 allocs/op (radix)
|
||||
|
||||
// Regex search is too slow to even complete
|
||||
// func BenchmarkRegexCache(b *testing.B) {
|
||||
// benchmarkRegexCache(b, newRegexCacheFactory)
|
||||
// }
|
||||
|
||||
func BenchmarkStringCache(b *testing.B) {
|
||||
benchmarkStringCache(b, newStringCacheFactory)
|
||||
}
|
||||
|
||||
func BenchmarkWildcardCache(b *testing.B) {
|
||||
benchmarkWildcardCache(b, newWildcardCacheFactory)
|
||||
}
|
||||
|
||||
// func benchmarkRegexCache(b *testing.B, newFactory func() cacheFactory) {
|
||||
// benchmarkCache(b, regexTestData, newFactory)
|
||||
// }
|
||||
|
||||
func benchmarkStringCache(b *testing.B, newFactory func() cacheFactory) {
|
||||
benchmarkCache(b, stringTestData, newFactory)
|
||||
}
|
||||
|
||||
func benchmarkWildcardCache(b *testing.B, newFactory func() cacheFactory) {
|
||||
benchmarkCache(b, wildcardTestData, newFactory)
|
||||
}
|
||||
|
||||
func benchmarkCache(b *testing.B, data []string, newFactory func() cacheFactory) {
|
||||
baseMemStats = readMemStats()
|
||||
|
||||
factory := newFactory()
|
||||
|
||||
for _, s := range data {
|
||||
factory.addEntry(s)
|
||||
}
|
||||
|
||||
cache := factory.create()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Always use the plain strings for search:
|
||||
// - wildcards and regexes need a plain string query
|
||||
// - all benchmarks will do the same number of queries
|
||||
for _, s := range stringTestData {
|
||||
if !cache.contains(s) {
|
||||
b.Fatalf("cache is missing value from stringTestData: %s", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
reportMemUsage(b, "cache", cache)
|
||||
}
|
||||
|
||||
// ---
|
||||
|
||||
func readMemStats() (res runtime.MemStats) {
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
|
||||
runtime.ReadMemStats(&res)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func reportMemUsage(b *testing.B, prefix string, toKeepAllocated ...any) {
|
||||
m := readMemStats()
|
||||
|
||||
b.ReportMetric(toMB(m.HeapAlloc-baseMemStats.HeapAlloc), prefix+"_heap_MB")
|
||||
|
||||
// Forces Go to keep the values allocated, meaning we include them in the above measurement
|
||||
// You can tell it works because factory benchmarks have different values for both calls
|
||||
for i := range toKeepAllocated {
|
||||
toKeepAllocated[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
func randString(n int) string {
|
||||
const charPool = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-."
|
||||
func toMB(b uint64) float64 {
|
||||
const bytesInKB = float64(1024)
|
||||
|
||||
b := make([]byte, n)
|
||||
kb := float64(b) / bytesInKB
|
||||
|
||||
for i := range b {
|
||||
b[i] = charPool[rand.Intn(len(charPool))]
|
||||
}
|
||||
|
||||
return string(b)
|
||||
return math.Round(kb) / 1024
|
||||
}
|
||||
|
||||
func createTestdata(count int) []string {
|
||||
var result []string
|
||||
func loadTestdata(path string) (res []string) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
result = append(result, randString(8+rand.Intn(20)))
|
||||
p := parsers.AllowErrors(parsers.Hosts(f), parsers.NoErrorLimit)
|
||||
p.OnErr(func(err error) {
|
||||
log.Log().Warnf("could not parse line in %s: %s", path, err)
|
||||
})
|
||||
|
||||
err = parsers.ForEach[*parsers.HostsIterator](context.Background(), p, func(hosts *parsers.HostsIterator) error {
|
||||
return hosts.ForEach(func(host string) error {
|
||||
res = append(res, host)
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return result
|
||||
return res
|
||||
}
|
||||
|
|
|
@ -10,15 +10,29 @@ var _ = Describe("Caches", func() {
|
|||
cache stringCache
|
||||
factory cacheFactory
|
||||
)
|
||||
|
||||
Describe("String StringCache", func() {
|
||||
It("should not return a cache when empty", func() {
|
||||
Expect(newStringCacheFactory().create()).Should(BeNil())
|
||||
})
|
||||
|
||||
It("should recognise the empty string", func() {
|
||||
factory := newStringCacheFactory()
|
||||
|
||||
Expect(factory.addEntry("")).Should(BeTrue())
|
||||
|
||||
Expect(factory.count()).Should(BeNumerically("==", 0))
|
||||
Expect(factory.create()).Should(BeNil())
|
||||
})
|
||||
|
||||
When("string StringCache was created", func() {
|
||||
BeforeEach(func() {
|
||||
factory = newStringCacheFactory()
|
||||
factory.addEntry("google.com")
|
||||
factory.addEntry("apple.com")
|
||||
factory.addEntry("")
|
||||
factory.addEntry("google.com")
|
||||
factory.addEntry("APPLe.com")
|
||||
Expect(factory.addEntry("google.com")).Should(BeTrue())
|
||||
Expect(factory.addEntry("apple.com")).Should(BeTrue())
|
||||
Expect(factory.addEntry("")).Should(BeTrue()) // invalid, but handled
|
||||
Expect(factory.addEntry("google.com")).Should(BeTrue())
|
||||
Expect(factory.addEntry("APPLe.com")).Should(BeTrue())
|
||||
|
||||
cache = factory.create()
|
||||
})
|
||||
|
@ -29,30 +43,50 @@ var _ = Describe("Caches", func() {
|
|||
Expect(cache.contains("www.google.com")).Should(BeFalse())
|
||||
Expect(cache.contains("")).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("should match case-insensitive", func() {
|
||||
Expect(cache.contains("aPPle.com")).Should(BeTrue())
|
||||
Expect(cache.contains("google.COM")).Should(BeTrue())
|
||||
Expect(cache.contains("www.google.com")).Should(BeFalse())
|
||||
Expect(cache.contains("")).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("should return correct element count", func() {
|
||||
Expect(factory.count()).Should(Equal(4))
|
||||
Expect(cache.elementCount()).Should(Equal(2))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Regex StringCache", func() {
|
||||
It("should not return a cache when empty", func() {
|
||||
Expect(newRegexCacheFactory().create()).Should(BeNil())
|
||||
})
|
||||
|
||||
It("should recognise invalid regexes", func() {
|
||||
factory := newRegexCacheFactory()
|
||||
|
||||
Expect(factory.addEntry("/*/")).Should(BeTrue())
|
||||
Expect(factory.addEntry("/?/")).Should(BeTrue())
|
||||
Expect(factory.addEntry("/+/")).Should(BeTrue())
|
||||
Expect(factory.addEntry("/[/")).Should(BeTrue())
|
||||
|
||||
Expect(factory.count()).Should(BeNumerically("==", 0))
|
||||
Expect(factory.create()).Should(BeNil())
|
||||
})
|
||||
|
||||
When("regex StringCache was created", func() {
|
||||
BeforeEach(func() {
|
||||
factory = newRegexCacheFactory()
|
||||
factory.addEntry("/.*google.com/")
|
||||
factory.addEntry("/^apple\\.(de|com)$/")
|
||||
factory.addEntry("/amazon/")
|
||||
// this is not a regex, will be ignored
|
||||
factory.addEntry("/(wrongRegex/")
|
||||
factory.addEntry("plaintext")
|
||||
Expect(factory.addEntry("/.*google.com/")).Should(BeTrue())
|
||||
Expect(factory.addEntry("/^apple\\.(de|com)$/")).Should(BeTrue())
|
||||
Expect(factory.addEntry("/amazon/")).Should(BeTrue())
|
||||
Expect(factory.addEntry("/(wrongRegex/")).Should(BeTrue()) // recognized as regex but ignored because invalid
|
||||
Expect(factory.addEntry("plaintext")).Should(BeFalse())
|
||||
|
||||
cache = factory.create()
|
||||
})
|
||||
|
||||
It("should match if one regex in StringCache matches string", func() {
|
||||
Expect(cache.contains("google.com")).Should(BeTrue())
|
||||
Expect(cache.contains("google.coma")).Should(BeTrue())
|
||||
|
@ -67,10 +101,73 @@ var _ = Describe("Caches", func() {
|
|||
Expect(cache.contains("amazon.com")).Should(BeTrue())
|
||||
Expect(cache.contains("myamazon.com")).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should return correct element count", func() {
|
||||
Expect(factory.count()).Should(Equal(3))
|
||||
Expect(cache.elementCount()).Should(Equal(3))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Wildcard StringCache", func() {
|
||||
It("should not return a cache when empty", func() {
|
||||
Expect(newWildcardCacheFactory().create()).Should(BeNil())
|
||||
})
|
||||
|
||||
It("should recognise invalid wildcards", func() {
|
||||
factory := newWildcardCacheFactory()
|
||||
|
||||
Expect(factory.addEntry("example.*.com")).Should(BeTrue())
|
||||
Expect(factory.addEntry("example.*")).Should(BeTrue())
|
||||
Expect(factory.addEntry("sub.*.example.com")).Should(BeTrue())
|
||||
Expect(factory.addEntry("*.example.*")).Should(BeTrue())
|
||||
|
||||
Expect(factory.count()).Should(BeNumerically("==", 0))
|
||||
Expect(factory.create()).Should(BeNil())
|
||||
})
|
||||
|
||||
When("cache was created", func() {
|
||||
BeforeEach(func() {
|
||||
factory = newWildcardCacheFactory()
|
||||
|
||||
Expect(factory.addEntry("*.example.com")).Should(BeTrue())
|
||||
Expect(factory.addEntry("*.example.org")).Should(BeTrue())
|
||||
Expect(factory.addEntry("*.blocked")).Should(BeTrue())
|
||||
Expect(factory.addEntry("*.sub.blocked")).Should(BeTrue()) // already handled by above
|
||||
|
||||
cache = factory.create()
|
||||
})
|
||||
|
||||
It("should match if one regex in StringCache matches string", func() {
|
||||
// first entry
|
||||
Expect(cache.contains("example.com")).Should(BeTrue())
|
||||
Expect(cache.contains("www.example.com")).Should(BeTrue())
|
||||
|
||||
// look alikes
|
||||
Expect(cache.contains("com")).Should(BeFalse())
|
||||
Expect(cache.contains("example.coma")).Should(BeFalse())
|
||||
Expect(cache.contains("an-example.com")).Should(BeFalse())
|
||||
Expect(cache.contains("examplecom")).Should(BeFalse())
|
||||
|
||||
// other entry
|
||||
Expect(cache.contains("example.org")).Should(BeTrue())
|
||||
Expect(cache.contains("www.example.org")).Should(BeTrue())
|
||||
|
||||
// unrelated
|
||||
Expect(cache.contains("example.net")).Should(BeFalse())
|
||||
Expect(cache.contains("www.example.net")).Should(BeFalse())
|
||||
|
||||
// third entry
|
||||
Expect(cache.contains("blocked")).Should(BeTrue())
|
||||
Expect(cache.contains("sub.blocked")).Should(BeTrue())
|
||||
Expect(cache.contains("sub.sub.blocked")).Should(BeTrue())
|
||||
Expect(cache.contains("example.blocked")).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should return correct element count", func() {
|
||||
Expect(factory.count()).Should(Equal(4))
|
||||
Expect(cache.elementCount()).Should(Equal(4))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -71,8 +71,8 @@ blocking:
|
|||
- https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt
|
||||
- |
|
||||
# inline definition with YAML literal block scalar style
|
||||
# hosts format
|
||||
someadsdomain.com
|
||||
*.example.com
|
||||
special:
|
||||
- https://raw.githubusercontent.com/StevenBlack/hosts/master/alternates/fakenews/hosts
|
||||
# definition of whitelist groups. Attention: if the same group has black and whitelists, whitelists will be used to disable particular blacklist entries. If a group has only whitelist entries -> this means only domains from this list are allowed, all other domains will be blocked
|
||||
|
|
|
@ -374,7 +374,8 @@ The supported list formats are:
|
|||
|
||||
1. the well-known [Hosts format](https://en.wikipedia.org/wiki/Hosts_(file))
|
||||
2. one domain per line (plain domain list)
|
||||
3. one regex per line
|
||||
3. one wildcard per line
|
||||
4. one regex per line
|
||||
|
||||
!!! example
|
||||
|
||||
|
@ -389,6 +390,7 @@ The supported list formats are:
|
|||
# content is in plain domain list format
|
||||
someadsdomain.com
|
||||
anotheradsdomain.com
|
||||
*.wildcard.example.com # blocks wildcard.example.com and all subdomains
|
||||
- |
|
||||
# inline definition with a regex
|
||||
/^banners?[_.-]/
|
||||
|
@ -414,6 +416,11 @@ The supported list formats are:
|
|||
!!! warning
|
||||
You must also define client group mapping, otherwise you black and whitelist definition will have no effect.
|
||||
|
||||
#### Wildcard support
|
||||
|
||||
You can use wildcards to block a domain and all its subdomains.
|
||||
Example: `*.example.com` will block `example.com` and `any.subdomains.example.com`.
|
||||
|
||||
#### Regex support
|
||||
|
||||
You can use regex to define patterns to block. A regex entry must start and end with the slash character (`/`). Some
|
||||
|
@ -423,6 +430,9 @@ Examples:
|
|||
- `/^baddomain/` will block `baddomain.com`, but not `www.baddomain.com`
|
||||
- `/^apple\.(de|com)$/` will only block `apple.de` and `apple.com`
|
||||
|
||||
!!! warning
|
||||
Regexes use more a lot more memory and are much slower than wildcards, you should use them as a last resort.
|
||||
|
||||
### Client groups
|
||||
|
||||
In this configuration section, you can define, which blocking group(s) should be used for which client in your network.
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -18,7 +18,10 @@ import (
|
|||
"github.com/ThinkChaos/parcour/jobgroup"
|
||||
)
|
||||
|
||||
const groupProducersBufferCap = 1000
|
||||
const (
|
||||
groupProducersBufferCap = 1000
|
||||
regexWarningThreshold = 500
|
||||
)
|
||||
|
||||
// ListCacheType represents the type of cached list ENUM(
|
||||
// blacklist // is a list with blocked domains
|
||||
|
@ -35,6 +38,7 @@ type Matcher interface {
|
|||
// ListCache generic cache of strings divided in groups
|
||||
type ListCache struct {
|
||||
groupedCache stringcache.GroupedStringCache
|
||||
regexCache stringcache.GroupedStringCache
|
||||
|
||||
cfg config.SourceLoadingConfig
|
||||
listType ListCacheType
|
||||
|
@ -44,12 +48,21 @@ type ListCache struct {
|
|||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (b *ListCache) LogConfig(logger *logrus.Entry) {
|
||||
var total int
|
||||
total := 0
|
||||
regexes := 0
|
||||
|
||||
for group := range b.groupSources {
|
||||
count := b.groupedCache.ElementCount(group)
|
||||
logger.Infof("%s: %d entries", group, count)
|
||||
total += count
|
||||
regexes += b.regexCache.ElementCount(group)
|
||||
}
|
||||
|
||||
if regexes > regexWarningThreshold {
|
||||
logger.Warnf(
|
||||
"REGEXES: %d !! High use of regexes is not recommended: they use a lot of memory and are very slow to search",
|
||||
regexes,
|
||||
)
|
||||
}
|
||||
|
||||
logger.Infof("TOTAL: %d entries", total)
|
||||
|
@ -60,11 +73,15 @@ func NewListCache(ctx context.Context,
|
|||
t ListCacheType, cfg config.SourceLoadingConfig,
|
||||
groupSources map[string][]config.BytesSource, downloader FileDownloader,
|
||||
) (*ListCache, error) {
|
||||
regexCache := stringcache.NewInMemoryGroupedRegexCache()
|
||||
|
||||
c := &ListCache{
|
||||
groupedCache: stringcache.NewChainedGroupedCache(
|
||||
stringcache.NewInMemoryGroupedStringCache(),
|
||||
stringcache.NewInMemoryGroupedRegexCache(),
|
||||
regexCache,
|
||||
stringcache.NewInMemoryGroupedWildcardCache(), // must be after regex which can contain '*'
|
||||
stringcache.NewInMemoryGroupedStringCache(), // accepts all values, must be last
|
||||
),
|
||||
regexCache: regexCache,
|
||||
|
||||
cfg: cfg,
|
||||
listType: t,
|
||||
|
@ -168,9 +185,11 @@ func (b *ListCache) createCacheForGroup(
|
|||
|
||||
producers.GoConsume(func(ctx context.Context, ch <-chan string) error {
|
||||
for host := range ch {
|
||||
hasEntries = true
|
||||
|
||||
groupFactory.AddEntry(host)
|
||||
if groupFactory.AddEntry(host) {
|
||||
hasEntries = true
|
||||
} else {
|
||||
logger().WithField("host", host).Warn("no list cache was able to use host")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -57,6 +57,7 @@ func (h *HostsIterator) UnmarshalText(data []byte) error {
|
|||
entries := []hostsIterator{
|
||||
new(HostListEntry),
|
||||
new(HostsFileEntry),
|
||||
new(WildcardEntry),
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
|
@ -200,6 +201,37 @@ func (e HostsFileEntry) forEachHost(callback func(string) error) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WildcardEntry is single domain wildcard.
|
||||
type WildcardEntry string
|
||||
|
||||
func (e WildcardEntry) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// We assume this is used with `Lines`:
|
||||
// - data will never be empty
|
||||
// - comments are stripped
|
||||
func (e *WildcardEntry) UnmarshalText(data []byte) error {
|
||||
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||
scanner.Split(bufio.ScanWords)
|
||||
|
||||
_ = scanner.Scan() // data is not empty
|
||||
|
||||
entry := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(entry, "*.") || strings.Count(entry, "*") > 1 {
|
||||
return fmt.Errorf("unsupported wildcard '%s': must start with '*.' and contain no other '*'", entry)
|
||||
}
|
||||
|
||||
*e = WildcardEntry(entry)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e WildcardEntry) forEachHost(callback func(string) error) error {
|
||||
return callback(e.String())
|
||||
}
|
||||
|
||||
func normalizeHostsListEntry(host string) (string, error) {
|
||||
var err error
|
||||
// Lookup is the profile preferred for DNS queries, we use Punycode here as it does less validation.
|
||||
|
|
|
@ -36,6 +36,7 @@ var _ = Describe("Hosts", func() {
|
|||
`/domain\.(tld|local)/`,
|
||||
`/^(.*\.)?2023\.xn--aptslabs-6fd\.net$/`,
|
||||
`müller.com`,
|
||||
`*.example.com`,
|
||||
)
|
||||
})
|
||||
|
||||
|
@ -70,11 +71,16 @@ var _ = Describe("Hosts", func() {
|
|||
Expect(iteratorToList(it.ForEach)).Should(Equal([]string{`xn--mller-kva.com`}))
|
||||
Expect(sut.Position()).Should(Equal("line 8"))
|
||||
|
||||
it, err = sut.Next(context.Background())
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(iteratorToList(it.ForEach)).Should(Equal([]string{"*.example.com"}))
|
||||
Expect(sut.Position()).Should(Equal("line 9"))
|
||||
|
||||
_, err = sut.Next(context.Background())
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err).Should(MatchError(io.EOF))
|
||||
Expect(IsNonResumableErr(err)).Should(BeTrue())
|
||||
Expect(sut.Position()).Should(Equal("line 9"))
|
||||
Expect(sut.Position()).Should(Equal("line 10"))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -85,6 +91,7 @@ var _ = Describe("Hosts", func() {
|
|||
"!notadomain!",
|
||||
"xn---mllerk1va.com",
|
||||
`/invalid regex ??/`,
|
||||
"invalid.*.wildcard",
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
|
|
|
@ -496,10 +496,9 @@ func (s *Server) printConfiguration() {
|
|||
runtime.ReadMemStats(&m)
|
||||
|
||||
logger().Infof(" memory:")
|
||||
logger().Infof(" alloc = %10v MB", toMB(m.Alloc))
|
||||
logger().Infof(" heapAlloc = %10v MB", toMB(m.HeapAlloc))
|
||||
logger().Infof(" sys = %10v MB", toMB(m.Sys))
|
||||
logger().Infof(" numGC = %10v", m.NumGC)
|
||||
logger().Infof(" heap = %10v MB", toMB(m.HeapAlloc))
|
||||
logger().Infof(" sys = %10v MB", toMB(m.Sys))
|
||||
logger().Infof(" numGC = %10v", m.NumGC)
|
||||
}
|
||||
|
||||
func toMB(b uint64) uint64 {
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package trie
|
||||
|
||||
import "strings"
|
||||
|
||||
type SplitFunc func(string) (label, rest string)
|
||||
|
||||
// www.example.com -> ("com", "www.example")
|
||||
func SplitTLD(domain string) (label, rest string) {
|
||||
domain = strings.TrimRight(domain, ".")
|
||||
|
||||
idx := strings.LastIndexByte(domain, '.')
|
||||
if idx == -1 {
|
||||
return domain, ""
|
||||
}
|
||||
|
||||
label = domain[idx+1:]
|
||||
rest = domain[:idx]
|
||||
|
||||
return label, rest
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package trie
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("SpltTLD", func() {
|
||||
It("should split a tld", func() {
|
||||
key, rest := SplitTLD("www.example.com")
|
||||
Expect(key).Should(Equal("com"))
|
||||
Expect(rest).Should(Equal("www.example"))
|
||||
})
|
||||
|
||||
It("should not split a plain string", func() {
|
||||
key, rest := SplitTLD("example")
|
||||
Expect(key).Should(Equal("example"))
|
||||
Expect(rest).Should(Equal(""))
|
||||
})
|
||||
|
||||
It("should not crash with an empty string", func() {
|
||||
key, rest := SplitTLD("")
|
||||
Expect(key).Should(Equal(""))
|
||||
Expect(rest).Should(Equal(""))
|
||||
})
|
||||
|
||||
It("should ignore trailing dots", func() {
|
||||
key, rest := SplitTLD("www.example.com.")
|
||||
Expect(key).Should(Equal("com"))
|
||||
Expect(rest).Should(Equal("www.example"))
|
||||
|
||||
key, rest = SplitTLD(rest)
|
||||
Expect(key).Should(Equal("example"))
|
||||
Expect(rest).Should(Equal("www"))
|
||||
})
|
||||
|
||||
It("should skip empty parts", func() {
|
||||
key, rest := SplitTLD("www.example..com")
|
||||
Expect(key).Should(Equal("com"))
|
||||
Expect(rest).Should(Equal("www.example."))
|
||||
|
||||
key, rest = SplitTLD(rest)
|
||||
Expect(key).Should(Equal("example"))
|
||||
Expect(rest).Should(Equal("www"))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,175 @@
|
|||
package trie
|
||||
|
||||
// Trie stores a set of strings and can quickly check
|
||||
// if it contains an element, or one of its parents.
|
||||
//
|
||||
// It implements a semi-radix/semi-compressed trie:
|
||||
// a node that would be a single child is merged with
|
||||
// its parent, if it is a terminal.
|
||||
//
|
||||
// The word "prefix" is avoided because in practice
|
||||
// we use the `Trie` with `SplitTLD` so parents are
|
||||
// suffixes even if in the datastructure they are
|
||||
// prefixes.
|
||||
type Trie struct {
|
||||
split SplitFunc
|
||||
root parent
|
||||
}
|
||||
|
||||
func NewTrie(split SplitFunc) *Trie {
|
||||
return &Trie{
|
||||
split: split,
|
||||
root: parent{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Trie) IsEmpty() bool {
|
||||
return t.root.children == nil
|
||||
}
|
||||
|
||||
func (t *Trie) Insert(key string) {
|
||||
t.root.insert(key, t.split)
|
||||
}
|
||||
|
||||
func (t *Trie) HasParentOf(key string) bool {
|
||||
return t.root.hasParentOf(key, t.split)
|
||||
}
|
||||
|
||||
type node interface {
|
||||
hasParentOf(key string, split SplitFunc) bool
|
||||
}
|
||||
|
||||
// We save memory by not keeping track of children of
|
||||
// nodes that are terminals (part of the set) as we only
|
||||
// ever need to know if a domain, or any of its parents,
|
||||
// is in the `Trie`.
|
||||
// Example: if the `Trie` contains "example.com", inserting
|
||||
// "www.example.com" has no effect as we already know it
|
||||
// is contained in the set.
|
||||
// Conversely, if it contains "www.example.com" and we insert
|
||||
// "example.com", then "www.example.com" is removed as it is
|
||||
// no longer useful.
|
||||
//
|
||||
// This means that all terminals are leafs and vice-versa.
|
||||
// So we save slightly more memory by avoiding a `isTerminal bool`
|
||||
// per parent.
|
||||
type parent struct {
|
||||
children map[string]node
|
||||
}
|
||||
|
||||
func newParent() *parent {
|
||||
return &parent{
|
||||
children: make(map[string]node, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *parent) insert(key string, split SplitFunc) {
|
||||
if len(key) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if n.children == nil {
|
||||
n.children = make(map[string]node, 1)
|
||||
}
|
||||
|
||||
label, rest := split(key)
|
||||
|
||||
child, ok := n.children[label]
|
||||
if !ok || len(rest) == 0 {
|
||||
n.children[label] = terminal(rest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch child := child.(type) {
|
||||
case *parent:
|
||||
// Continue down the trie
|
||||
key = rest
|
||||
n = child
|
||||
|
||||
continue
|
||||
|
||||
case terminal:
|
||||
if child.hasParentOf(rest, split) {
|
||||
// Found a parent/"prefix" in the set
|
||||
return
|
||||
}
|
||||
|
||||
p := newParent()
|
||||
n.children[label] = p
|
||||
|
||||
p.insert(child.String(), split) // keep existing terminal
|
||||
p.insert(rest, split) // add new value
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *parent) hasParentOf(key string, split SplitFunc) bool {
|
||||
for {
|
||||
label, rest := split(key)
|
||||
|
||||
child, ok := n.children[label]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch child := child.(type) {
|
||||
case *parent:
|
||||
if len(rest) == 0 {
|
||||
// The trie only contains children/"suffixes" of the
|
||||
// key we're searching for
|
||||
return false
|
||||
}
|
||||
|
||||
// Continue down the trie
|
||||
key = rest
|
||||
n = child
|
||||
|
||||
continue
|
||||
|
||||
case terminal:
|
||||
// Continue down the trie
|
||||
return child.hasParentOf(rest, split)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type terminal string
|
||||
|
||||
func (t terminal) String() string {
|
||||
return string(t)
|
||||
}
|
||||
|
||||
func (t terminal) hasParentOf(searchKey string, split SplitFunc) bool {
|
||||
tKey := t.String()
|
||||
if tKey == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
for {
|
||||
tLabel, tRest := split(tKey)
|
||||
|
||||
searchLabel, searchRest := split(searchKey)
|
||||
if searchLabel != tLabel {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(tRest) == 0 {
|
||||
// Found a parent/"prefix" in the set
|
||||
return true
|
||||
}
|
||||
|
||||
if len(searchRest) == 0 {
|
||||
// The trie only contains children/"suffixes" of the
|
||||
// key we're searching for
|
||||
return false
|
||||
}
|
||||
|
||||
// Continue down the trie
|
||||
searchKey = searchRest
|
||||
tKey = tRest
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
package trie
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestTrie(t *testing.T) {
|
||||
log.Silence()
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Trie Suite")
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package trie
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Trie", func() {
|
||||
var sut *Trie
|
||||
|
||||
BeforeEach(func() {
|
||||
sut = NewTrie(SplitTLD)
|
||||
})
|
||||
|
||||
Describe("Basic operations", func() {
|
||||
When("Trie is created", func() {
|
||||
It("should be empty", func() {
|
||||
Expect(sut.IsEmpty()).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should not find domains", func() {
|
||||
Expect(sut.HasParentOf("example.com")).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("should not insert the empty string", func() {
|
||||
sut.Insert("")
|
||||
Expect(sut.HasParentOf("")).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
When("Adding a domain", func() {
|
||||
var (
|
||||
domainOkTLD = "com"
|
||||
domainOk = "example." + domainOkTLD
|
||||
|
||||
domainKo = "example.org"
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
Expect(sut.HasParentOf(domainOk)).Should(BeFalse())
|
||||
sut.Insert(domainOk)
|
||||
Expect(sut.HasParentOf(domainOk)).Should(BeTrue())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(sut.HasParentOf(domainOk)).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should be found", func() {})
|
||||
|
||||
It("should contain subdomains", func() {
|
||||
subdomain := "www." + domainOk
|
||||
|
||||
Expect(sut.HasParentOf(subdomain)).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should support inserting subdomains", func() {
|
||||
subdomain := "www." + domainOk
|
||||
|
||||
Expect(sut.HasParentOf(subdomain)).Should(BeTrue())
|
||||
sut.Insert(subdomain)
|
||||
Expect(sut.HasParentOf(subdomain)).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should not find unrelated", func() {
|
||||
Expect(sut.HasParentOf(domainKo)).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("should not find uninserted parent", func() {
|
||||
Expect(sut.HasParentOf(domainOkTLD)).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("should not find deep uninserted parent", func() {
|
||||
sut.Insert("sub.sub.sub.test")
|
||||
|
||||
Expect(sut.HasParentOf("sub.sub.test")).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("should find inserted parent", func() {
|
||||
sut.Insert(domainOkTLD)
|
||||
Expect(sut.HasParentOf(domainOkTLD)).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should insert sibling", func() {
|
||||
sibling := "other." + domainOkTLD
|
||||
|
||||
sut.Insert(sibling)
|
||||
Expect(sut.HasParentOf(sibling)).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("should insert grand-children siblings", func() {
|
||||
base := "other.com"
|
||||
abcSub := "abc." + base
|
||||
xyzSub := "xyz." + base
|
||||
|
||||
sut.Insert(abcSub)
|
||||
Expect(sut.HasParentOf(abcSub)).Should(BeTrue())
|
||||
Expect(sut.HasParentOf(xyzSub)).Should(BeFalse())
|
||||
Expect(sut.HasParentOf(base)).Should(BeFalse())
|
||||
|
||||
sut.Insert(xyzSub)
|
||||
Expect(sut.HasParentOf(xyzSub)).Should(BeTrue())
|
||||
Expect(sut.HasParentOf(abcSub)).Should(BeTrue())
|
||||
Expect(sut.HasParentOf(base)).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
Loading…
Reference in New Issue