chore(refactor): refactor cache implementation (#1174)

* chore(refactor): refactor cache implementation

* chore: use atomic.Uint32 as prefetch names query count

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

---------

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
Dimitri Herzog 2023-09-30 17:14:59 +02:00 committed by GitHub
parent d76740ea51
commit 497bd0d0fd
9 changed files with 508 additions and 159 deletions


@@ -19,15 +19,18 @@ type element[T any] struct {
 type ExpiringLRUCache[T any] struct {
 	cleanUpInterval time.Duration
 	preExpirationFn OnExpirationCallback[T]
+	onCacheHit      OnCacheHitCallback
+	onCacheMiss     OnCacheMissCallback
+	onAfterPut      OnAfterPutCallback
 	lru             *lru.Cache
 }
 
-type CacheOption[T any] func(c *ExpiringLRUCache[T])
-
-func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
-	return func(e *ExpiringLRUCache[T]) {
-		e.cleanUpInterval = d
-	}
+type Options struct {
+	OnCacheHitFn    OnCacheHitCallback
+	OnCacheMissFn   OnCacheMissCallback
+	OnAfterPutFn    OnAfterPutCallback
+	CleanupInterval time.Duration
+	MaxSize         uint
 }
 
 // OnExpirationCallback will be called just before an element gets expired and will
@@ -35,33 +38,56 @@ func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
 // element in the cache or nil to remove it
 type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
 
-func WithOnExpiredFn[T any](fn OnExpirationCallback[T]) CacheOption[T] {
-	return func(c *ExpiringLRUCache[T]) {
-		c.preExpirationFn = fn
-	}
-}
+// OnCacheHitCallback will be called on cache get if entry was found
+type OnCacheHitCallback func(key string)
 
-func WithMaxSize[T any](size uint) CacheOption[T] {
-	return func(c *ExpiringLRUCache[T]) {
-		if size > 0 {
-			l, _ := lru.New(int(size))
-			c.lru = l
-		}
-	}
-}
+// OnCacheMissCallback will be called on cache get and entry was not found
+type OnCacheMissCallback func(key string)
 
-func NewCache[T any](options ...CacheOption[T]) *ExpiringLRUCache[T] {
+// OnAfterPutCallback will be called after put, receives new element count as parameter
+type OnAfterPutCallback func(newSize int)
+
+func NewCache[T any](options Options) *ExpiringLRUCache[T] {
+	return NewCacheWithOnExpired[T](options, nil)
+}
+
+func NewCacheWithOnExpired[T any](options Options,
+	onExpirationFn OnExpirationCallback[T],
+) *ExpiringLRUCache[T] {
 	l, _ := lru.New(defaultSize)
 	c := &ExpiringLRUCache[T]{
 		cleanUpInterval: defaultCleanUpInterval,
 		preExpirationFn: func(key string) (val *T, ttl time.Duration) {
 			return nil, 0
 		},
-		lru: l,
+		onCacheHit:  func(key string) {},
+		onCacheMiss: func(key string) {},
+		lru:         l,
 	}
 
-	for _, opt := range options {
-		opt(c)
+	if options.CleanupInterval > 0 {
+		c.cleanUpInterval = options.CleanupInterval
+	}
+
+	if options.MaxSize > 0 {
+		l, _ := lru.New(int(options.MaxSize))
+		c.lru = l
+	}
+
+	if options.OnAfterPutFn != nil {
+		c.onAfterPut = options.OnAfterPutFn
+	}
+
+	if options.OnCacheHitFn != nil {
+		c.onCacheHit = options.OnCacheHitFn
+	}
+
+	if options.OnCacheMissFn != nil {
+		c.onCacheMiss = options.OnCacheMissFn
+	}
+
+	if onExpirationFn != nil {
+		c.preExpirationFn = onExpirationFn
 	}
 
 	go periodicCleanup(c)
@@ -122,15 +148,23 @@ func (e *ExpiringLRUCache[T]) Put(key string, val *T, ttl time.Duration) {
 		val:            val,
 		expiresEpochMs: expiresEpochMs,
 	})
+
+	if e.onAfterPut != nil {
+		e.onAfterPut(e.lru.Len())
+	}
 }
 
 func (e *ExpiringLRUCache[T]) Get(key string) (val *T, ttl time.Duration) {
 	el, found := e.lru.Get(key)
 	if found {
+		e.onCacheHit(key)
+
 		return el.(*element[T]).val, calculateRemainTTL(el.(*element[T]).expiresEpochMs)
 	}
 
+	e.onCacheMiss(key)
+
 	return nil, 0
 }
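The constructor API changes from variadic functional options to a single Options struct; callers that previously used WithOnExpiredFn now pass the callback to NewCacheWithOnExpired. A minimal usage sketch of the new API follows (the keys, values, and the refresh callback are illustrative, not part of this change):

package main

import (
	"fmt"
	"time"

	"github.com/0xERR0R/blocky/cache/expirationcache"
)

func main() {
	// Plain cache: all tuning and hooks go through the Options struct.
	cache := expirationcache.NewCache[string](expirationcache.Options{
		CleanupInterval: time.Minute,
		MaxSize:         100,
		OnCacheHitFn:    func(key string) { fmt.Println("hit:", key) },
		OnCacheMissFn:   func(key string) { fmt.Println("miss:", key) },
		OnAfterPutFn:    func(newSize int) { fmt.Println("new size:", newSize) },
	})

	v := "value"
	cache.Put("key1", &v, 30*time.Second)

	if val, ttl := cache.Get("key1"); val != nil {
		fmt.Println(*val, ttl)
	}

	// Cache with an expiration callback: the former WithOnExpiredFn option
	// is now the second argument of NewCacheWithOnExpired.
	refreshing := expirationcache.NewCacheWithOnExpired[string](expirationcache.Options{
		CleanupInterval: time.Minute,
	}, func(key string) (*string, time.Duration) {
		refreshed := "refreshed-" + key

		return &refreshed, time.Minute
	})
	_ = refreshing
}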


@@ -11,11 +11,11 @@ var _ = Describe("Expiration cache", func() {
 	Describe("Basic operations", func() {
 		When("string cache was created", func() {
 			It("Initial cache should be empty", func() {
-				cache := NewCache[string]()
+				cache := NewCache[string](Options{})
 				Expect(cache.TotalCount()).Should(Equal(0))
 			})
 			It("Initial cache should not contain any elements", func() {
-				cache := NewCache[string]()
+				cache := NewCache[string](Options{})
 				val, expiration := cache.Get("key1")
 				Expect(val).Should(BeNil())
 				Expect(expiration).Should(Equal(time.Duration(0)))
@@ -23,7 +23,7 @@ var _ = Describe("Expiration cache", func() {
 		})
 		When("Put new value with positive TTL", func() {
 			It("Should return the value before element expires", func() {
-				cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
+				cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
 				v := "v1"
 				cache.Put("key1", &v, 50*time.Millisecond)
 				val, expiration := cache.Get("key1")
@@ -33,7 +33,7 @@ var _ = Describe("Expiration cache", func() {
 				Expect(cache.TotalCount()).Should(Equal(1))
 			})
 			It("Should return nil after expiration", func() {
-				cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
+				cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
 				v := "v1"
 				cache.Put("key1", &v, 50*time.Millisecond)
 
@@ -52,7 +52,7 @@ var _ = Describe("Expiration cache", func() {
 		})
 		When("Put new value without expiration", func() {
 			It("Should not cache the value", func() {
-				cache := NewCache(WithCleanUpInterval[string](50 * time.Millisecond))
+				cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond})
 				v := "x"
 				cache.Put("key1", &v, 0)
 				val, expiration := cache.Get("key1")
@@ -63,7 +63,7 @@ var _ = Describe("Expiration cache", func() {
 		})
 		When("Put updated value", func() {
 			It("Should return updated value", func() {
-				cache := NewCache[string]()
+				cache := NewCache[string](Options{})
 				v1 := "v1"
 				v2 := "v2"
 				cache.Put("key1", &v1, 50*time.Millisecond)
@@ -79,7 +79,7 @@ var _ = Describe("Expiration cache", func() {
 		})
 		When("Purging after usage", func() {
 			It("Should be empty after purge", func() {
-				cache := NewCache[string]()
+				cache := NewCache[string](Options{})
 				v1 := "y"
 				cache.Put("key1", &v1, time.Second)
 
@@ -91,15 +91,62 @@ var _ = Describe("Expiration cache", func() {
 			})
 		})
 	})
 
+	Describe("Hook functions", func() {
+		When("Hook functions are defined", func() {
+			It("should call each hook function", func() {
+				onCacheHitChannel := make(chan string, 10)
+				onCacheMissChannel := make(chan string, 10)
+				onAfterPutChannel := make(chan int, 10)
+				cache := NewCache[string](Options{
+					OnCacheHitFn: func(key string) {
+						onCacheHitChannel <- key
+					},
+					OnCacheMissFn: func(key string) {
+						onCacheMissChannel <- key
+					},
+					OnAfterPutFn: func(newSize int) {
+						onAfterPutChannel <- newSize
+					},
+				})
+
+				By("Get non existing value", func() {
+					val, _ := cache.Get("notExists")
+					Expect(val).Should(BeNil())
+
+					Expect(onCacheMissChannel).Should(Receive(Equal("notExists")))
+					Expect(onCacheHitChannel).Should(Not(Receive()))
+					Expect(onAfterPutChannel).Should(Not(Receive()))
+				})
+
+				By("Put new cache entry", func() {
+					v1 := "v1"
+					cache.Put("key1", &v1, time.Second)
+
+					Expect(onCacheMissChannel).Should(Not(Receive()))
+					Expect(onCacheMissChannel).Should(Not(Receive()))
+					Expect(onAfterPutChannel).Should(Receive(Equal(1)))
+				})
+
+				By("Get existing value", func() {
+					val, _ := cache.Get("key1")
+					Expect(val).Should(HaveValue(Equal("v1")))
+
+					Expect(onCacheMissChannel).Should(Not(Receive()))
+					Expect(onCacheHitChannel).Should(Receive(Equal("key1")))
+					Expect(onAfterPutChannel).Should(Not(Receive()))
+				})
+			})
+		})
+	})
+
 	Describe("preExpiration function", func() {
-		When(" function is defined", func() {
+		When("function is defined", func() {
 			It("should update the value and TTL if function returns values", func() {
 				fn := func(key string) (val *string, ttl time.Duration) {
 					v2 := "v2"
 
 					return &v2, time.Second
 				}
-				cache := NewCache(WithOnExpiredFn(fn))
+				cache := NewCacheWithOnExpired[string](Options{}, fn)
 
 				v1 := "v1"
 				cache.Put("key1", &v1, 50*time.Millisecond)
@@ -118,7 +165,7 @@ var _ = Describe("Expiration cache", func() {
 					return &v2, time.Second
 				}
-				cache := NewCache(WithOnExpiredFn(fn))
+				cache := NewCacheWithOnExpired[string](Options{}, fn)
 
 				v1 := "somval"
 				cache.Put("key1", &v1, time.Millisecond)
@@ -139,7 +186,7 @@ var _ = Describe("Expiration cache", func() {
 				fn := func(key string) (val *string, ttl time.Duration) {
 					return nil, 0
 				}
-				cache := NewCache(WithCleanUpInterval[string](100*time.Millisecond), WithOnExpiredFn(fn))
+				cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn)
 
 				v1 := "z"
 				cache.Put("key1", &v1, 50*time.Millisecond)
@@ -152,7 +199,7 @@ var _ = Describe("Expiration cache", func() {
 	Describe("LRU behaviour", func() {
 		When("Defined max size is reached", func() {
 			It("should remove old elements", func() {
-				cache := NewCache(WithMaxSize[string](3))
+				cache := NewCache[string](Options{MaxSize: 3})
 				v1 := "val1"
 				v2 := "val2"


@@ -0,0 +1,127 @@
package expirationcache

import (
	"sync/atomic"
	"time"
)

type PrefetchingExpiringLRUCache[T any] struct {
	cache                   ExpiringCache[cacheValue[T]]
	prefetchingNameCache    ExpiringCache[atomic.Uint32]
	reloadFn                ReloadEntryFn[T]
	prefetchThreshold       int
	prefetchExpires         time.Duration
	onPrefetchEntryReloaded OnEntryReloadedCallback
	onPrefetchCacheHit      OnCacheHitCallback
}

type cacheValue[T any] struct {
	element  *T
	prefetch bool
}

// OnEntryReloadedCallback will be called if a prefetched entry is reloaded
type OnEntryReloadedCallback func(key string)

// ReloadEntryFn reloads a prefetched entry by key
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)

type PrefetchingOptions[T any] struct {
	Options
	ReloadFn                func(cacheKey string) (*T, time.Duration)
	PrefetchThreshold       int
	PrefetchExpires         time.Duration
	PrefetchMaxItemsCount   int
	OnPrefetchAfterPut      OnAfterPutCallback
	OnPrefetchEntryReloaded OnEntryReloadedCallback
	OnPrefetchCacheHit      OnCacheHitCallback
}

type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]])

func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
	pc := &PrefetchingExpiringLRUCache[T]{
		prefetchingNameCache: NewCache[atomic.Uint32](Options{
			CleanupInterval: time.Minute,
			MaxSize:         uint(options.PrefetchMaxItemsCount),
			OnAfterPutFn:    options.OnPrefetchAfterPut,
		}),
		prefetchExpires:         options.PrefetchExpires,
		prefetchThreshold:       options.PrefetchThreshold,
		reloadFn:                options.ReloadFn,
		onPrefetchEntryReloaded: options.OnPrefetchEntryReloaded,
		onPrefetchCacheHit:      options.OnPrefetchCacheHit,
	}

	pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired)

	return pc
}

// check if a cache entry should be prefetched: was queried > threshold in the time window
func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
	if e.prefetchThreshold == 0 {
		return true
	}

	cnt, _ := e.prefetchingNameCache.Get(cacheKey)

	return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
}

func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
	if e.shouldPrefetch(cacheKey) {
		loadedVal, ttl := e.reloadFn(cacheKey)
		if loadedVal != nil {
			if e.onPrefetchEntryReloaded != nil {
				e.onPrefetchEntryReloaded(cacheKey)
			}

			return &cacheValue[T]{loadedVal, true}, ttl
		}
	}

	return nil, 0
}

func (e *PrefetchingExpiringLRUCache[T]) trackCacheKeyQueryCount(cacheKey string) {
	var x *atomic.Uint32
	if x, _ = e.prefetchingNameCache.Get(cacheKey); x == nil {
		x = &atomic.Uint32{}
	}

	x.Add(1)
	e.prefetchingNameCache.Put(cacheKey, x, e.prefetchExpires)
}

func (e *PrefetchingExpiringLRUCache[T]) Put(key string, val *T, expiration time.Duration) {
	e.cache.Put(key, &cacheValue[T]{element: val, prefetch: false}, expiration)
}

// Get returns the value of cached entry with remained TTL. If entry is not cached, returns nil
func (e *PrefetchingExpiringLRUCache[T]) Get(key string) (val *T, expiration time.Duration) {
	e.trackCacheKeyQueryCount(key)

	res, exp := e.cache.Get(key)
	if res == nil {
		return nil, exp
	}

	if e.onPrefetchCacheHit != nil && res.prefetch {
		// Hit from prefetch cache
		e.onPrefetchCacheHit(key)
	}

	return res.element, exp
}

// TotalCount returns the total count of valid (not expired) elements
func (e *PrefetchingExpiringLRUCache[T]) TotalCount() int {
	return e.cache.TotalCount()
}

// Clear removes all cache entries
func (e *PrefetchingExpiringLRUCache[T]) Clear() {
	e.cache.Clear()
}
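The prefetching logic that used to live in CachingResolver is now generic: every Get bumps a per-key *atomic.Uint32 counter in the name cache, and when an entry expires, onExpired reloads it through ReloadFn if that counter exceeded PrefetchThreshold within the PrefetchExpires window. A hedged usage sketch (the key, values, and reload result below are illustrative):

package main

import (
	"fmt"
	"time"

	"github.com/0xERR0R/blocky/cache/expirationcache"
)

func main() {
	cache := expirationcache.NewPrefetchingCache[string](expirationcache.PrefetchingOptions[string]{
		Options: expirationcache.Options{
			CleanupInterval: time.Second,
		},
		// Reload on expiration only if the key was requested more than
		// PrefetchThreshold times within the PrefetchExpires window.
		PrefetchThreshold: 2,
		PrefetchExpires:   time.Minute,
		ReloadFn: func(cacheKey string) (*string, time.Duration) {
			v := "reloaded-" + cacheKey

			return &v, time.Minute
		},
		OnPrefetchEntryReloaded: func(key string) { fmt.Println("reloaded:", key) },
		OnPrefetchCacheHit:      func(key string) { fmt.Println("prefetch hit:", key) },
	})

	v := "initial"
	cache.Put("example.com", &v, time.Second)

	// Each Get increments the per-key atomic.Uint32 counter in the name cache,
	// so after enough lookups the entry is refreshed on expiry instead of dropped.
	for i := 0; i < 3; i++ {
		cache.Get("example.com")
	}
}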


@@ -0,0 +1,183 @@
package expirationcache
import (
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Prefetching expiration cache", func() {
Describe("Basic operations", func() {
When("string cache was created", func() {
It("Initial cache should be empty", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
Expect(cache.TotalCount()).Should(Equal(0))
})
It("Initial cache should not contain any elements", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
val, expiration := cache.Get("key1")
Expect(val).Should(BeNil())
Expect(expiration).Should(Equal(time.Duration(0)))
})
It("Should work as cache (basic operations)", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
Expect(cache.TotalCount()).Should(Equal(1))
cache.Clear()
Expect(cache.TotalCount()).Should(Equal(0))
})
})
Context("Prefetching", func() {
It("Should prefetch element", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
})
By("put a value and get it again", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
Expect(cache.TotalCount()).Should(Equal(1))
})
By("Get value twice to trigger prefetching", func() {
val, _ := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(HaveValue(Equal("v2")))
}).Should(Succeed())
})
})
It("Should not prefetch element", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
})
By("put a value and get it again", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
Expect(cache.TotalCount()).Should(Equal(1))
})
By("Wait for expiration -> the entry should not be prefetched, threshold was not reached", func() {
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(BeNil())
}, "5s", "500ms").Should(Succeed())
})
})
It("With default config (threshold = 0) should always prefetch", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
})
By("put a value and get it again", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
})
By("Should return new prefetched value after expiration", func() {
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(HaveValue(Equal("v2")))
}, "5s").Should(Succeed())
})
})
It("Should execute hook functions", func() {
onPrefetchAfterPutChannel := make(chan int, 10)
onPrefetchEntryReloaded := make(chan string, 10)
onnPrefetchCacheHit := make(chan string, 10)
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
OnPrefetchAfterPut: func(newSize int) { onPrefetchAfterPutChannel <- newSize },
OnPrefetchEntryReloaded: func(key string) { onPrefetchEntryReloaded <- key },
OnPrefetchCacheHit: func(key string) { onnPrefetchCacheHit <- key },
})
By("put a value", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
Expect(onPrefetchAfterPutChannel).Should(Not(Receive()))
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
})
By("get a value 3 times to trigger prefetching", func() {
// first get
cache.Get("key1")
Expect(onPrefetchAfterPutChannel).Should(Receive(Equal(1)))
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
// second get
val, _ := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
// third get -> this should trigger prefetching after expiration
cache.Get("key1")
// reload was executed
Eventually(onPrefetchEntryReloaded).Should(Receive(Equal("key1")))
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
// has new value
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(HaveValue(Equal("v2")))
}, "5s").Should(Succeed())
// prefetch hit
Eventually(onnPrefetchCacheHit).Should(Receive(Equal("key1")))
})
})
})
})
})


@@ -20,10 +20,10 @@ const (
 	// CachingPrefetchCacheHit fires if a query result was found in the prefetch cache, Parameter: domain name
 	CachingPrefetchCacheHit = "caching:prefetchHit"
 
-	// CachingResultCacheHit fires, if a query result was found in the cache, Parameter: domain name
+	// CachingResultCacheHit fires, if a query result was found in the cache
 	CachingResultCacheHit = "caching:cacheHit"
 
-	// CachingResultCacheMiss fires, if a query result was not found in the cache, Parameter: domain name
+	// CachingResultCacheMiss fires, if a query result was not found in the cache
 	CachingResultCacheMiss = "caching:cacheMiss"
 
 	// CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count


@@ -603,10 +603,11 @@ func (r *BlockingResolver) initFQDNIPCache() {
 		identifiers = append(identifiers, identifier)
 	}
 
-	r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]net.IP](defaultBlockingCleanUpInterval),
-		expirationcache.WithOnExpiredFn(func(key string) (val *[]net.IP, ttl time.Duration) {
-			return r.queryForFQIdentifierIPs(key)
-		}))
+	r.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{
+		CleanupInterval: defaultBlockingCleanUpInterval,
+	}, func(key string) (val *[]net.IP, ttl time.Duration) {
+		return r.queryForFQIdentifierIPs(key)
+	})
 
 	for _, identifier := range identifiers {
 		if isFQDN(identifier) {


@@ -29,15 +29,9 @@ type CachingResolver struct {
 	emitMetricEvents bool // disabled by Bootstrap
 
-	resultCache          expirationcache.ExpiringCache[cacheValue]
-	prefetchingNameCache expirationcache.ExpiringCache[int]
-
-	redisClient *redis.Client
-}
+	resultCache expirationcache.ExpiringCache[dns.Msg]
 
-// cacheValue includes query answer and prefetch flag
-type cacheValue struct {
-	resultMsg *dns.Msg
-	prefetch  bool
+	redisClient *redis.Client
 }
 
 // NewCachingResolver creates a new resolver instance
@@ -65,73 +59,76 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
 }
 
 func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
-	cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval)
-	maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount))
+	options := expirationcache.Options{
+		CleanupInterval: defaultCachingCleanUpInterval,
+		MaxSize:         uint(cfg.MaxItemsCount),
+		OnCacheHitFn: func(key string) {
+			c.publishMetricsIfEnabled(evt.CachingResultCacheHit, key)
+		},
+		OnCacheMissFn: func(key string) {
+			c.publishMetricsIfEnabled(evt.CachingResultCacheMiss, key)
+		},
+		OnAfterPutFn: func(newSize int) {
+			c.publishMetricsIfEnabled(evt.CachingResultCacheChanged, newSize)
+		},
+	}
 
 	if cfg.Prefetching {
-		c.prefetchingNameCache = expirationcache.NewCache(
-			expirationcache.WithCleanUpInterval[int](time.Minute),
-			expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)),
-		)
+		prefetchingOptions := expirationcache.PrefetchingOptions[dns.Msg]{
+			Options:               options,
+			PrefetchExpires:       time.Duration(cfg.PrefetchExpires),
+			PrefetchThreshold:     cfg.PrefetchThreshold,
+			PrefetchMaxItemsCount: cfg.PrefetchMaxItemsCount,
+			ReloadFn:              c.reloadCacheEntry,
+			OnPrefetchAfterPut: func(newSize int) {
+				c.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, newSize)
+			},
+			OnPrefetchEntryReloaded: func(key string) {
+				c.publishMetricsIfEnabled(evt.CachingDomainPrefetched, key)
+			},
+			OnPrefetchCacheHit: func(key string) {
+				c.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, key)
+			},
+		}
 
-		c.resultCache = expirationcache.NewCache(
-			cleanupOption,
-			maxSizeOption,
-			expirationcache.WithOnExpiredFn(c.onExpired),
-		)
+		c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions)
 	} else {
-		c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption)
+		c.resultCache = expirationcache.NewCache[dns.Msg](options)
 	}
 }
 
+func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Duration) {
+	qType, domainName := util.ExtractCacheKey(cacheKey)
+	logger := r.log()
+
+	logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
+
+	req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
+
+	response, err := r.next.Resolve(req)
+	if err == nil {
+		if response.Res.Rcode == dns.RcodeSuccess {
+			return response.Res, r.adjustTTLs(response.Res.Answer)
+		}
+	} else {
+		util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
+	}
+
+	return nil, 0
+}
+
 func setupRedisCacheSubscriber(c *CachingResolver) {
 	go func() {
 		for rc := range c.redisClient.CacheChannel {
 			if rc != nil {
 				c.log().Debug("Received key from redis: ", rc.Key)
 				ttl := c.adjustTTLs(rc.Response.Res.Answer)
-				c.putInCache(rc.Key, rc.Response, ttl, false, false)
+				c.putInCache(rc.Key, rc.Response, ttl, false)
 			}
 		}
 	}()
 }
 
-// check if domain was queried > threshold in the time window
-func (r *CachingResolver) shouldPrefetch(cacheKey string) bool {
-	if r.cfg.PrefetchThreshold == 0 {
-		return true
-	}
-
-	cnt, _ := r.prefetchingNameCache.Get(cacheKey)
-
-	return cnt != nil && *cnt > r.cfg.PrefetchThreshold
-}
-
-func (r *CachingResolver) onExpired(cacheKey string) (val *cacheValue, ttl time.Duration) {
-	qType, domainName := util.ExtractCacheKey(cacheKey)
-
-	if r.shouldPrefetch(cacheKey) {
-		logger := r.log()
-
-		logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
-
-		req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
-
-		response, err := r.next.Resolve(req)
-		if err == nil {
-			if response.Res.Rcode == dns.RcodeSuccess {
-				r.publishMetricsIfEnabled(evt.CachingDomainPrefetched, domainName)
-
-				return &cacheValue{response.Res, true}, r.adjustTTLs(response.Res.Answer)
-			}
-		} else {
-			util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
-		}
-	}
-
-	return nil, 0
-}
-
 // LogConfig implements `config.Configurable`.
 func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
 	r.cfg.LogConfig(logger)
@@ -155,23 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
 	cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
 	logger := logger.WithField("domain", util.Obfuscate(domain))
 
-	r.trackQueryDomainNameCount(domain, cacheKey, logger)
-
 	val, ttl := r.resultCache.Get(cacheKey)
 
 	if val != nil {
 		logger.Debug("domain is cached")
 
-		r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain)
-
-		if val.prefetch {
-			// Hit from prefetch cache
-			r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
-		}
-
-		resp := val.resultMsg.Copy()
+		resp := val.Copy()
 		resp.SetReply(request.Req)
-		resp.Rcode = val.resultMsg.Rcode
+		resp.Rcode = val.Rcode
 
 		// Adjust TTL
 		setTTLInCachedResponse(resp, ttl)
@@ -183,14 +171,12 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
 			return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
 		}
 
-		r.publishMetricsIfEnabled(evt.CachingResultCacheMiss, domain)
-
 		logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
 
 		response, err = r.next.Resolve(request)
 		if err == nil {
 			cacheTTL := r.adjustTTLs(response.Res.Answer)
-			r.putInCache(cacheKey, response, cacheTTL, false, true)
+			r.putInCache(cacheKey, response, cacheTTL, true)
 		}
 	}
@@ -209,22 +195,6 @@ func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
 	}
 }
 
-func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, logger *logrus.Entry) {
-	if r.prefetchingNameCache != nil {
-		var domainCount int
-		if x, _ := r.prefetchingNameCache.Get(cacheKey); x != nil {
-			domainCount = *x
-		}
-		domainCount++
-		r.prefetchingNameCache.Put(cacheKey, &domainCount, r.cfg.PrefetchExpires.ToDuration())
-		totalCount := r.prefetchingNameCache.TotalCount()
-
-		logger.Debugf("domain '%s' was requested %d times, "+
-			"total cache size: %d", util.Obfuscate(domain), domainCount, totalCount)
-		r.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, totalCount)
-	}
-}
-
 // removes EDNS OPT records from message
 func removeEdns0Extra(msg *dns.Msg) {
 	if len(msg.Extra) > 0 {
@@ -246,7 +216,7 @@ func shouldBeCached(msg *dns.Msg) bool {
 }
 
 func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
-	prefetch, publish bool,
+	publish bool,
 ) {
 	respCopy := response.Res.Copy()
@@ -255,16 +225,14 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
 	if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) {
 		// put value into cache
-		r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, ttl)
+		r.resultCache.Put(cacheKey, respCopy, ttl)
 	} else if response.Res.Rcode == dns.RcodeNameError {
 		if r.cfg.CacheTimeNegative.IsAboveZero() {
 			// put negative cache if result code is NXDOMAIN
-			r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
+			r.resultCache.Put(cacheKey, respCopy, r.cfg.CacheTimeNegative.ToDuration())
 		}
 	}
 
-	r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount())
-
 	if publish && r.redisClient != nil {
 		res := *respCopy
 		r.redisClient.PublishCache(cacheKey, &res)
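With the hooks above, CachingResolver no longer publishes hit/miss/size metrics inline in Resolve and putInCache; the wiring happens once in configureCaches. A condensed sketch of that pattern, where publish stands in for blocky's publishMetricsIfEnabled and the event labels are illustrative, not blocky's event constants:

package main

import (
	"fmt"
	"time"

	"github.com/0xERR0R/blocky/cache/expirationcache"
)

// publish stands in for blocky's publishMetricsIfEnabled / event bus.
func publish(event string, payload any) {
	fmt.Println(event, payload)
}

func main() {
	options := expirationcache.Options{
		CleanupInterval: 5 * time.Second,
		MaxSize:         1000,
		// Metric events are emitted by the cache itself instead of by
		// inline publish calls scattered through the resolver.
		OnCacheHitFn:  func(key string) { publish("resultCacheHit", key) },
		OnCacheMissFn: func(key string) { publish("resultCacheMiss", key) },
		OnAfterPutFn:  func(newSize int) { publish("resultCacheChanged", newSize) },
	}

	cache := expirationcache.NewCache[string](options)

	v := "cached answer"
	cache.Put("example.com", &v, time.Minute)
	cache.Get("example.com")   // triggers OnCacheHitFn
	cache.Get("unknown.local") // triggers OnCacheMissFn
}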


@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"time"
 
-	"github.com/0xERR0R/blocky/cache/expirationcache"
 	"github.com/0xERR0R/blocky/config"
 	. "github.com/0xERR0R/blocky/evt"
 	. "github.com/0xERR0R/blocky/helpertest"
@@ -81,18 +80,15 @@ var _ = Describe("CachingResolver", func() {
 				// prepare resolver, set smaller caching times for testing
 				prefetchThreshold := 5
 				configureCaches(sut, &sutConfig)
-				sut.resultCache = expirationcache.NewCache(
-					expirationcache.WithCleanUpInterval[cacheValue](100*time.Millisecond),
-					expirationcache.WithOnExpiredFn(sut.onExpired))
 
-				domainPrefetched := make(chan string, 1)
-				prefetchHitDomain := make(chan string, 1)
+				domainPrefetched := make(chan bool, 1)
+				prefetchHitDomain := make(chan bool, 1)
 				prefetchedCnt := make(chan int, 1)
 				Expect(Bus().SubscribeOnce(CachingPrefetchCacheHit, func(domain string) {
-					prefetchHitDomain <- domain
+					prefetchHitDomain <- true
 				})).Should(Succeed())
 
 				Expect(Bus().SubscribeOnce(CachingDomainPrefetched, func(domain string) {
-					domainPrefetched <- domain
+					domainPrefetched <- true
 				})).Should(Succeed())
 
 				Expect(Bus().SubscribeOnce(CachingDomainsToPrefetchCountChanged, func(cnt int) {
@@ -115,7 +111,7 @@ var _ = Describe("CachingResolver", func() {
 				}
 
 				// now is this domain prefetched
-				Eventually(domainPrefetched, "4s").Should(Receive(Equal("example.com")))
+				Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
 
 				// and it should hit from prefetch cache
 				Expect(sut.Resolve(newRequest("example.com.", A))).
@@ -125,16 +121,7 @@ var _ = Describe("CachingResolver", func() {
 						HaveReturnCode(dns.RcodeSuccess),
 						BeDNSRecord("example.com.", A, "123.122.121.120"),
 						HaveTTL(BeNumerically("<=", 2))))
-				Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com")))
-			})
-
-			When("threshold is 0", func() {
-				BeforeEach(func() {
-					sutConfig.PrefetchThreshold = 0
-				})
-				It("should always prefetch", func() {
-					Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue())
-				})
+				Eventually(prefetchHitDomain, "10s").Should(Receive(Equal(true)))
 			})
 		})
@@ -204,9 +191,9 @@ var _ = Describe("CachingResolver", func() {
 			It("should cache response and use response's TTL", func() {
 				By("first request", func() {
-					domain := make(chan string, 1)
+					domain := make(chan bool, 1)
 					_ = Bus().SubscribeOnce(CachingResultCacheMiss, func(d string) {
-						domain <- d
+						domain <- true
 					})
 
 					totalCacheCount := make(chan int, 1)
@@ -223,15 +210,15 @@ var _ = Describe("CachingResolver", func() {
 					Expect(m.Calls).Should(HaveLen(1))
 
-					Expect(domain).Should(Receive(Equal("example.com")))
+					Expect(domain).Should(Receive(Equal(true)))
 					Expect(totalCacheCount).Should(Receive(Equal(1)))
 				})
 
 				By("second request", func() {
 					Eventually(func(g Gomega) {
-						domain := make(chan string, 1)
+						domain := make(chan bool, 1)
 						_ = Bus().SubscribeOnce(CachingResultCacheHit, func(d string) {
-							domain <- d
+							domain <- true
 						})
 
 						g.Expect(sut.Resolve(newRequest("example.com.", A))).
@@ -246,7 +233,7 @@ var _ = Describe("CachingResolver", func() {
 						// still one call to upstream
 						g.Expect(m.Calls).Should(HaveLen(1))
-						g.Expect(domain).Should(Receive(Equal("example.com")))
+						g.Expect(domain).Should(Receive(Equal(true)))
 					}, "1s").Should(Succeed())
 				})
 			})


@@ -41,7 +41,9 @@ func NewClientNamesResolver(
 		configurable: withConfig(&cfg),
 		typed:        withType("client_names"),
 
-		cache:            expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]string](time.Hour)),
+		cache: expirationcache.NewCache[[]string](expirationcache.Options{
+			CleanupInterval: time.Hour,
+		}),
 		externalResolver: r,
 	}