mirror of https://github.com/0xERR0R/blocky.git
chore(refactor): refactor cache implementation (#1174)
* chore(refactor): refactor cache implementation * chore: use atomic.Uint32 as prefetch names query count Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> --------- Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
This commit is contained in:
parent
d76740ea51
commit
497bd0d0fd
|
@ -19,15 +19,18 @@ type element[T any] struct {
|
||||||
type ExpiringLRUCache[T any] struct {
|
type ExpiringLRUCache[T any] struct {
|
||||||
cleanUpInterval time.Duration
|
cleanUpInterval time.Duration
|
||||||
preExpirationFn OnExpirationCallback[T]
|
preExpirationFn OnExpirationCallback[T]
|
||||||
|
onCacheHit OnCacheHitCallback
|
||||||
|
onCacheMiss OnCacheMissCallback
|
||||||
|
onAfterPut OnAfterPutCallback
|
||||||
lru *lru.Cache
|
lru *lru.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
type CacheOption[T any] func(c *ExpiringLRUCache[T])
|
type Options struct {
|
||||||
|
OnCacheHitFn OnCacheHitCallback
|
||||||
func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
|
OnCacheMissFn OnCacheMissCallback
|
||||||
return func(e *ExpiringLRUCache[T]) {
|
OnAfterPutFn OnAfterPutCallback
|
||||||
e.cleanUpInterval = d
|
CleanupInterval time.Duration
|
||||||
}
|
MaxSize uint
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnExpirationCallback will be called just before an element gets expired and will
|
// OnExpirationCallback will be called just before an element gets expired and will
|
||||||
|
@ -35,33 +38,56 @@ func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
|
||||||
// element in the cache or nil to remove it
|
// element in the cache or nil to remove it
|
||||||
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
|
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
|
||||||
|
|
||||||
func WithOnExpiredFn[T any](fn OnExpirationCallback[T]) CacheOption[T] {
|
// OnCacheHitCallback will be called on cache get if entry was found
|
||||||
return func(c *ExpiringLRUCache[T]) {
|
type OnCacheHitCallback func(key string)
|
||||||
c.preExpirationFn = fn
|
|
||||||
}
|
// OnCacheMissCallback will be called on cache get and entry was not found
|
||||||
|
type OnCacheMissCallback func(key string)
|
||||||
|
|
||||||
|
// OnAfterPutCallback will be called after put, receives new element count as parameter
|
||||||
|
type OnAfterPutCallback func(newSize int)
|
||||||
|
|
||||||
|
func NewCache[T any](options Options) *ExpiringLRUCache[T] {
|
||||||
|
return NewCacheWithOnExpired[T](options, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithMaxSize[T any](size uint) CacheOption[T] {
|
func NewCacheWithOnExpired[T any](options Options,
|
||||||
return func(c *ExpiringLRUCache[T]) {
|
onExpirationFn OnExpirationCallback[T],
|
||||||
if size > 0 {
|
) *ExpiringLRUCache[T] {
|
||||||
l, _ := lru.New(int(size))
|
|
||||||
c.lru = l
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCache[T any](options ...CacheOption[T]) *ExpiringLRUCache[T] {
|
|
||||||
l, _ := lru.New(defaultSize)
|
l, _ := lru.New(defaultSize)
|
||||||
c := &ExpiringLRUCache[T]{
|
c := &ExpiringLRUCache[T]{
|
||||||
cleanUpInterval: defaultCleanUpInterval,
|
cleanUpInterval: defaultCleanUpInterval,
|
||||||
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
|
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
},
|
},
|
||||||
lru: l,
|
onCacheHit: func(key string) {},
|
||||||
|
onCacheMiss: func(key string) {},
|
||||||
|
lru: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, opt := range options {
|
if options.CleanupInterval > 0 {
|
||||||
opt(c)
|
c.cleanUpInterval = options.CleanupInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.MaxSize > 0 {
|
||||||
|
l, _ := lru.New(int(options.MaxSize))
|
||||||
|
c.lru = l
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.OnAfterPutFn != nil {
|
||||||
|
c.onAfterPut = options.OnAfterPutFn
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.OnCacheHitFn != nil {
|
||||||
|
c.onCacheHit = options.OnCacheHitFn
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.OnCacheMissFn != nil {
|
||||||
|
c.onCacheMiss = options.OnCacheMissFn
|
||||||
|
}
|
||||||
|
|
||||||
|
if onExpirationFn != nil {
|
||||||
|
c.preExpirationFn = onExpirationFn
|
||||||
}
|
}
|
||||||
|
|
||||||
go periodicCleanup(c)
|
go periodicCleanup(c)
|
||||||
|
@ -122,15 +148,23 @@ func (e *ExpiringLRUCache[T]) Put(key string, val *T, ttl time.Duration) {
|
||||||
val: val,
|
val: val,
|
||||||
expiresEpochMs: expiresEpochMs,
|
expiresEpochMs: expiresEpochMs,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if e.onAfterPut != nil {
|
||||||
|
e.onAfterPut(e.lru.Len())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ExpiringLRUCache[T]) Get(key string) (val *T, ttl time.Duration) {
|
func (e *ExpiringLRUCache[T]) Get(key string) (val *T, ttl time.Duration) {
|
||||||
el, found := e.lru.Get(key)
|
el, found := e.lru.Get(key)
|
||||||
|
|
||||||
if found {
|
if found {
|
||||||
|
e.onCacheHit(key)
|
||||||
|
|
||||||
return el.(*element[T]).val, calculateRemainTTL(el.(*element[T]).expiresEpochMs)
|
return el.(*element[T]).val, calculateRemainTTL(el.(*element[T]).expiresEpochMs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.onCacheMiss(key)
|
||||||
|
|
||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,11 +11,11 @@ var _ = Describe("Expiration cache", func() {
|
||||||
Describe("Basic operations", func() {
|
Describe("Basic operations", func() {
|
||||||
When("string cache was created", func() {
|
When("string cache was created", func() {
|
||||||
It("Initial cache should be empty", func() {
|
It("Initial cache should be empty", func() {
|
||||||
cache := NewCache[string]()
|
cache := NewCache[string](Options{})
|
||||||
Expect(cache.TotalCount()).Should(Equal(0))
|
Expect(cache.TotalCount()).Should(Equal(0))
|
||||||
})
|
})
|
||||||
It("Initial cache should not contain any elements", func() {
|
It("Initial cache should not contain any elements", func() {
|
||||||
cache := NewCache[string]()
|
cache := NewCache[string](Options{})
|
||||||
val, expiration := cache.Get("key1")
|
val, expiration := cache.Get("key1")
|
||||||
Expect(val).Should(BeNil())
|
Expect(val).Should(BeNil())
|
||||||
Expect(expiration).Should(Equal(time.Duration(0)))
|
Expect(expiration).Should(Equal(time.Duration(0)))
|
||||||
|
@ -23,7 +23,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
})
|
})
|
||||||
When("Put new value with positive TTL", func() {
|
When("Put new value with positive TTL", func() {
|
||||||
It("Should return the value before element expires", func() {
|
It("Should return the value before element expires", func() {
|
||||||
cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
|
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
|
||||||
v := "v1"
|
v := "v1"
|
||||||
cache.Put("key1", &v, 50*time.Millisecond)
|
cache.Put("key1", &v, 50*time.Millisecond)
|
||||||
val, expiration := cache.Get("key1")
|
val, expiration := cache.Get("key1")
|
||||||
|
@ -33,7 +33,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
Expect(cache.TotalCount()).Should(Equal(1))
|
Expect(cache.TotalCount()).Should(Equal(1))
|
||||||
})
|
})
|
||||||
It("Should return nil after expiration", func() {
|
It("Should return nil after expiration", func() {
|
||||||
cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
|
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
|
||||||
v := "v1"
|
v := "v1"
|
||||||
cache.Put("key1", &v, 50*time.Millisecond)
|
cache.Put("key1", &v, 50*time.Millisecond)
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
})
|
})
|
||||||
When("Put new value without expiration", func() {
|
When("Put new value without expiration", func() {
|
||||||
It("Should not cache the value", func() {
|
It("Should not cache the value", func() {
|
||||||
cache := NewCache(WithCleanUpInterval[string](50 * time.Millisecond))
|
cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond})
|
||||||
v := "x"
|
v := "x"
|
||||||
cache.Put("key1", &v, 0)
|
cache.Put("key1", &v, 0)
|
||||||
val, expiration := cache.Get("key1")
|
val, expiration := cache.Get("key1")
|
||||||
|
@ -63,7 +63,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
})
|
})
|
||||||
When("Put updated value", func() {
|
When("Put updated value", func() {
|
||||||
It("Should return updated value", func() {
|
It("Should return updated value", func() {
|
||||||
cache := NewCache[string]()
|
cache := NewCache[string](Options{})
|
||||||
v1 := "v1"
|
v1 := "v1"
|
||||||
v2 := "v2"
|
v2 := "v2"
|
||||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||||
|
@ -79,7 +79,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
})
|
})
|
||||||
When("Purging after usage", func() {
|
When("Purging after usage", func() {
|
||||||
It("Should be empty after purge", func() {
|
It("Should be empty after purge", func() {
|
||||||
cache := NewCache[string]()
|
cache := NewCache[string](Options{})
|
||||||
v1 := "y"
|
v1 := "y"
|
||||||
cache.Put("key1", &v1, time.Second)
|
cache.Put("key1", &v1, time.Second)
|
||||||
|
|
||||||
|
@ -91,15 +91,62 @@ var _ = Describe("Expiration cache", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
Describe("Hook functions", func() {
|
||||||
|
When("Hook functions are defined", func() {
|
||||||
|
It("should call each hook function", func() {
|
||||||
|
onCacheHitChannel := make(chan string, 10)
|
||||||
|
onCacheMissChannel := make(chan string, 10)
|
||||||
|
onAfterPutChannel := make(chan int, 10)
|
||||||
|
cache := NewCache[string](Options{
|
||||||
|
OnCacheHitFn: func(key string) {
|
||||||
|
onCacheHitChannel <- key
|
||||||
|
},
|
||||||
|
OnCacheMissFn: func(key string) {
|
||||||
|
onCacheMissChannel <- key
|
||||||
|
},
|
||||||
|
OnAfterPutFn: func(newSize int) {
|
||||||
|
onAfterPutChannel <- newSize
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
By("Get non existing value", func() {
|
||||||
|
val, _ := cache.Get("notExists")
|
||||||
|
Expect(val).Should(BeNil())
|
||||||
|
|
||||||
|
Expect(onCacheMissChannel).Should(Receive(Equal("notExists")))
|
||||||
|
Expect(onCacheHitChannel).Should(Not(Receive()))
|
||||||
|
Expect(onAfterPutChannel).Should(Not(Receive()))
|
||||||
|
})
|
||||||
|
|
||||||
|
By("Put new cache entry", func() {
|
||||||
|
v1 := "v1"
|
||||||
|
cache.Put("key1", &v1, time.Second)
|
||||||
|
Expect(onCacheMissChannel).Should(Not(Receive()))
|
||||||
|
Expect(onCacheMissChannel).Should(Not(Receive()))
|
||||||
|
Expect(onAfterPutChannel).Should(Receive(Equal(1)))
|
||||||
|
})
|
||||||
|
|
||||||
|
By("Get existing value", func() {
|
||||||
|
val, _ := cache.Get("key1")
|
||||||
|
Expect(val).Should(HaveValue(Equal("v1")))
|
||||||
|
|
||||||
|
Expect(onCacheMissChannel).Should(Not(Receive()))
|
||||||
|
Expect(onCacheHitChannel).Should(Receive(Equal("key1")))
|
||||||
|
Expect(onAfterPutChannel).Should(Not(Receive()))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
Describe("preExpiration function", func() {
|
Describe("preExpiration function", func() {
|
||||||
When(" function is defined", func() {
|
When("function is defined", func() {
|
||||||
It("should update the value and TTL if function returns values", func() {
|
It("should update the value and TTL if function returns values", func() {
|
||||||
fn := func(key string) (val *string, ttl time.Duration) {
|
fn := func(key string) (val *string, ttl time.Duration) {
|
||||||
v2 := "v2"
|
v2 := "v2"
|
||||||
|
|
||||||
return &v2, time.Second
|
return &v2, time.Second
|
||||||
}
|
}
|
||||||
cache := NewCache(WithOnExpiredFn(fn))
|
|
||||||
|
cache := NewCacheWithOnExpired[string](Options{}, fn)
|
||||||
v1 := "v1"
|
v1 := "v1"
|
||||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||||
|
|
||||||
|
@ -118,7 +165,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
|
|
||||||
return &v2, time.Second
|
return &v2, time.Second
|
||||||
}
|
}
|
||||||
cache := NewCache(WithOnExpiredFn(fn))
|
cache := NewCacheWithOnExpired[string](Options{}, fn)
|
||||||
v1 := "somval"
|
v1 := "somval"
|
||||||
cache.Put("key1", &v1, time.Millisecond)
|
cache.Put("key1", &v1, time.Millisecond)
|
||||||
|
|
||||||
|
@ -139,7 +186,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
fn := func(key string) (val *string, ttl time.Duration) {
|
fn := func(key string) (val *string, ttl time.Duration) {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
cache := NewCache(WithCleanUpInterval[string](100*time.Millisecond), WithOnExpiredFn(fn))
|
cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn)
|
||||||
v1 := "z"
|
v1 := "z"
|
||||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||||
|
|
||||||
|
@ -152,7 +199,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
Describe("LRU behaviour", func() {
|
Describe("LRU behaviour", func() {
|
||||||
When("Defined max size is reached", func() {
|
When("Defined max size is reached", func() {
|
||||||
It("should remove old elements", func() {
|
It("should remove old elements", func() {
|
||||||
cache := NewCache(WithMaxSize[string](3))
|
cache := NewCache[string](Options{MaxSize: 3})
|
||||||
|
|
||||||
v1 := "val1"
|
v1 := "val1"
|
||||||
v2 := "val2"
|
v2 := "val2"
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
package expirationcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PrefetchingExpiringLRUCache[T any] struct {
|
||||||
|
cache ExpiringCache[cacheValue[T]]
|
||||||
|
prefetchingNameCache ExpiringCache[atomic.Uint32]
|
||||||
|
reloadFn ReloadEntryFn[T]
|
||||||
|
prefetchThreshold int
|
||||||
|
prefetchExpires time.Duration
|
||||||
|
onPrefetchEntryReloaded OnEntryReloadedCallback
|
||||||
|
onPrefetchCacheHit OnCacheHitCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheValue[T any] struct {
|
||||||
|
element *T
|
||||||
|
prefetch bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnEntryReloadedCallback will be called if a prefetched entry is reloaded
|
||||||
|
type OnEntryReloadedCallback func(key string)
|
||||||
|
|
||||||
|
// ReloadEntryFn reloads a prefetched entry by key
|
||||||
|
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
|
||||||
|
|
||||||
|
type PrefetchingOptions[T any] struct {
|
||||||
|
Options
|
||||||
|
ReloadFn func(cacheKey string) (*T, time.Duration)
|
||||||
|
PrefetchThreshold int
|
||||||
|
PrefetchExpires time.Duration
|
||||||
|
PrefetchMaxItemsCount int
|
||||||
|
OnPrefetchAfterPut OnAfterPutCallback
|
||||||
|
OnPrefetchEntryReloaded OnEntryReloadedCallback
|
||||||
|
OnPrefetchCacheHit OnCacheHitCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]])
|
||||||
|
|
||||||
|
func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
|
||||||
|
pc := &PrefetchingExpiringLRUCache[T]{
|
||||||
|
prefetchingNameCache: NewCache[atomic.Uint32](Options{
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
MaxSize: uint(options.PrefetchMaxItemsCount),
|
||||||
|
OnAfterPutFn: options.OnPrefetchAfterPut,
|
||||||
|
}),
|
||||||
|
prefetchExpires: options.PrefetchExpires,
|
||||||
|
prefetchThreshold: options.PrefetchThreshold,
|
||||||
|
reloadFn: options.ReloadFn,
|
||||||
|
onPrefetchEntryReloaded: options.OnPrefetchEntryReloaded,
|
||||||
|
onPrefetchCacheHit: options.OnPrefetchCacheHit,
|
||||||
|
}
|
||||||
|
|
||||||
|
pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired)
|
||||||
|
|
||||||
|
return pc
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if a cache entry should be prefetched: was queried > threshold in the time window
|
||||||
|
func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
|
||||||
|
if e.prefetchThreshold == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
cnt, _ := e.prefetchingNameCache.Get(cacheKey)
|
||||||
|
|
||||||
|
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
|
||||||
|
if e.shouldPrefetch(cacheKey) {
|
||||||
|
loadedVal, ttl := e.reloadFn(cacheKey)
|
||||||
|
if loadedVal != nil {
|
||||||
|
if e.onPrefetchEntryReloaded != nil {
|
||||||
|
e.onPrefetchEntryReloaded(cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cacheValue[T]{loadedVal, true}, ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *PrefetchingExpiringLRUCache[T]) trackCacheKeyQueryCount(cacheKey string) {
|
||||||
|
var x *atomic.Uint32
|
||||||
|
if x, _ = e.prefetchingNameCache.Get(cacheKey); x == nil {
|
||||||
|
x = &atomic.Uint32{}
|
||||||
|
}
|
||||||
|
|
||||||
|
x.Add(1)
|
||||||
|
e.prefetchingNameCache.Put(cacheKey, x, e.prefetchExpires)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *PrefetchingExpiringLRUCache[T]) Put(key string, val *T, expiration time.Duration) {
|
||||||
|
e.cache.Put(key, &cacheValue[T]{element: val, prefetch: false}, expiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the value of cached entry with remained TTL. If entry is not cached, returns nil
|
||||||
|
func (e *PrefetchingExpiringLRUCache[T]) Get(key string) (val *T, expiration time.Duration) {
|
||||||
|
e.trackCacheKeyQueryCount(key)
|
||||||
|
|
||||||
|
res, exp := e.cache.Get(key)
|
||||||
|
|
||||||
|
if res == nil {
|
||||||
|
return nil, exp
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.onPrefetchCacheHit != nil && res.prefetch {
|
||||||
|
// Hit from prefetch cache
|
||||||
|
e.onPrefetchCacheHit(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.element, exp
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalCount returns the total count of valid (not expired) elements
|
||||||
|
func (e *PrefetchingExpiringLRUCache[T]) TotalCount() int {
|
||||||
|
return e.cache.TotalCount()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all cache entries
|
||||||
|
func (e *PrefetchingExpiringLRUCache[T]) Clear() {
|
||||||
|
e.cache.Clear()
|
||||||
|
}
|
|
@ -0,0 +1,183 @@
|
||||||
|
package expirationcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Prefetching expiration cache", func() {
|
||||||
|
Describe("Basic operations", func() {
|
||||||
|
When("string cache was created", func() {
|
||||||
|
It("Initial cache should be empty", func() {
|
||||||
|
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||||
|
Expect(cache.TotalCount()).Should(Equal(0))
|
||||||
|
})
|
||||||
|
It("Initial cache should not contain any elements", func() {
|
||||||
|
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||||
|
val, expiration := cache.Get("key1")
|
||||||
|
Expect(val).Should(BeNil())
|
||||||
|
Expect(expiration).Should(Equal(time.Duration(0)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Should work as cache (basic operations)", func() {
|
||||||
|
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||||
|
v := "v1"
|
||||||
|
cache.Put("key1", &v, 50*time.Millisecond)
|
||||||
|
|
||||||
|
val, expiration := cache.Get("key1")
|
||||||
|
Expect(val).Should(HaveValue(Equal("v1")))
|
||||||
|
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||||
|
|
||||||
|
Expect(cache.TotalCount()).Should(Equal(1))
|
||||||
|
|
||||||
|
cache.Clear()
|
||||||
|
|
||||||
|
Expect(cache.TotalCount()).Should(Equal(0))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
Context("Prefetching", func() {
|
||||||
|
It("Should prefetch element", func() {
|
||||||
|
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||||
|
Options: Options{
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
},
|
||||||
|
PrefetchThreshold: 2,
|
||||||
|
PrefetchExpires: 100 * time.Millisecond,
|
||||||
|
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||||
|
v := "v2"
|
||||||
|
|
||||||
|
return &v, 50 * time.Millisecond
|
||||||
|
},
|
||||||
|
})
|
||||||
|
By("put a value and get it again", func() {
|
||||||
|
v := "v1"
|
||||||
|
cache.Put("key1", &v, 50*time.Millisecond)
|
||||||
|
val, expiration := cache.Get("key1")
|
||||||
|
Expect(val).Should(HaveValue(Equal("v1")))
|
||||||
|
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||||
|
|
||||||
|
Expect(cache.TotalCount()).Should(Equal(1))
|
||||||
|
})
|
||||||
|
By("Get value twice to trigger prefetching", func() {
|
||||||
|
val, _ := cache.Get("key1")
|
||||||
|
Expect(val).Should(HaveValue(Equal("v1")))
|
||||||
|
|
||||||
|
Eventually(func(g Gomega) {
|
||||||
|
val, _ := cache.Get("key1")
|
||||||
|
g.Expect(val).Should(HaveValue(Equal("v2")))
|
||||||
|
}).Should(Succeed())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
It("Should not prefetch element", func() {
|
||||||
|
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||||
|
Options: Options{
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
},
|
||||||
|
PrefetchThreshold: 2,
|
||||||
|
PrefetchExpires: 100 * time.Millisecond,
|
||||||
|
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||||
|
v := "v2"
|
||||||
|
|
||||||
|
return &v, 50 * time.Millisecond
|
||||||
|
},
|
||||||
|
})
|
||||||
|
By("put a value and get it again", func() {
|
||||||
|
v := "v1"
|
||||||
|
cache.Put("key1", &v, 50*time.Millisecond)
|
||||||
|
val, expiration := cache.Get("key1")
|
||||||
|
Expect(val).Should(HaveValue(Equal("v1")))
|
||||||
|
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||||
|
|
||||||
|
Expect(cache.TotalCount()).Should(Equal(1))
|
||||||
|
})
|
||||||
|
By("Wait for expiration -> the entry should not be prefetched, threshold was not reached", func() {
|
||||||
|
Eventually(func(g Gomega) {
|
||||||
|
val, _ := cache.Get("key1")
|
||||||
|
g.Expect(val).Should(BeNil())
|
||||||
|
}, "5s", "500ms").Should(Succeed())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
It("With default config (threshold = 0) should always prefetch", func() {
|
||||||
|
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||||
|
Options: Options{
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
},
|
||||||
|
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||||
|
v := "v2"
|
||||||
|
|
||||||
|
return &v, 50 * time.Millisecond
|
||||||
|
},
|
||||||
|
})
|
||||||
|
By("put a value and get it again", func() {
|
||||||
|
v := "v1"
|
||||||
|
cache.Put("key1", &v, 50*time.Millisecond)
|
||||||
|
val, expiration := cache.Get("key1")
|
||||||
|
Expect(val).Should(HaveValue(Equal("v1")))
|
||||||
|
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||||
|
})
|
||||||
|
By("Should return new prefetched value after expiration", func() {
|
||||||
|
Eventually(func(g Gomega) {
|
||||||
|
val, _ := cache.Get("key1")
|
||||||
|
g.Expect(val).Should(HaveValue(Equal("v2")))
|
||||||
|
}, "5s").Should(Succeed())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
It("Should execute hook functions", func() {
|
||||||
|
onPrefetchAfterPutChannel := make(chan int, 10)
|
||||||
|
onPrefetchEntryReloaded := make(chan string, 10)
|
||||||
|
onnPrefetchCacheHit := make(chan string, 10)
|
||||||
|
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||||
|
Options: Options{
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
},
|
||||||
|
PrefetchThreshold: 2,
|
||||||
|
PrefetchExpires: 100 * time.Millisecond,
|
||||||
|
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||||
|
v := "v2"
|
||||||
|
|
||||||
|
return &v, 50 * time.Millisecond
|
||||||
|
},
|
||||||
|
OnPrefetchAfterPut: func(newSize int) { onPrefetchAfterPutChannel <- newSize },
|
||||||
|
OnPrefetchEntryReloaded: func(key string) { onPrefetchEntryReloaded <- key },
|
||||||
|
OnPrefetchCacheHit: func(key string) { onnPrefetchCacheHit <- key },
|
||||||
|
})
|
||||||
|
By("put a value", func() {
|
||||||
|
v := "v1"
|
||||||
|
cache.Put("key1", &v, 50*time.Millisecond)
|
||||||
|
Expect(onPrefetchAfterPutChannel).Should(Not(Receive()))
|
||||||
|
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
|
||||||
|
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
|
||||||
|
})
|
||||||
|
By("get a value 3 times to trigger prefetching", func() {
|
||||||
|
// first get
|
||||||
|
cache.Get("key1")
|
||||||
|
|
||||||
|
Expect(onPrefetchAfterPutChannel).Should(Receive(Equal(1)))
|
||||||
|
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
|
||||||
|
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
|
||||||
|
|
||||||
|
// secont get
|
||||||
|
val, _ := cache.Get("key1")
|
||||||
|
Expect(val).Should(HaveValue(Equal("v1")))
|
||||||
|
|
||||||
|
// third get -> this should trigger prefetching after expiration
|
||||||
|
cache.Get("key1")
|
||||||
|
|
||||||
|
// reload was executed
|
||||||
|
Eventually(onPrefetchEntryReloaded).Should(Receive(Equal("key1")))
|
||||||
|
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
|
||||||
|
// has new value
|
||||||
|
Eventually(func(g Gomega) {
|
||||||
|
val, _ := cache.Get("key1")
|
||||||
|
g.Expect(val).Should(HaveValue(Equal("v2")))
|
||||||
|
}, "5s").Should(Succeed())
|
||||||
|
|
||||||
|
// prefetch hit
|
||||||
|
Eventually(onnPrefetchCacheHit).Should(Receive(Equal("key1")))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
|
@ -20,10 +20,10 @@ const (
|
||||||
// CachingPrefetchCacheHit fires if a query result was found in the prefetch cache, Parameter: domain name
|
// CachingPrefetchCacheHit fires if a query result was found in the prefetch cache, Parameter: domain name
|
||||||
CachingPrefetchCacheHit = "caching:prefetchHit"
|
CachingPrefetchCacheHit = "caching:prefetchHit"
|
||||||
|
|
||||||
// CachingResultCacheHit fires, if a query result was found in the cache, Parameter: domain name
|
// CachingResultCacheHit fires, if a query result was found in the cache
|
||||||
CachingResultCacheHit = "caching:cacheHit"
|
CachingResultCacheHit = "caching:cacheHit"
|
||||||
|
|
||||||
// CachingResultCacheMiss fires, if a query result was not found in the cache, Parameter: domain name
|
// CachingResultCacheMiss fires, if a query result was not found in the cache
|
||||||
CachingResultCacheMiss = "caching:cacheMiss"
|
CachingResultCacheMiss = "caching:cacheMiss"
|
||||||
|
|
||||||
// CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count
|
// CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count
|
||||||
|
|
|
@ -603,10 +603,11 @@ func (r *BlockingResolver) initFQDNIPCache() {
|
||||||
identifiers = append(identifiers, identifier)
|
identifiers = append(identifiers, identifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]net.IP](defaultBlockingCleanUpInterval),
|
r.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{
|
||||||
expirationcache.WithOnExpiredFn(func(key string) (val *[]net.IP, ttl time.Duration) {
|
CleanupInterval: defaultBlockingCleanUpInterval,
|
||||||
return r.queryForFQIdentifierIPs(key)
|
}, func(key string) (val *[]net.IP, ttl time.Duration) {
|
||||||
}))
|
return r.queryForFQIdentifierIPs(key)
|
||||||
|
})
|
||||||
|
|
||||||
for _, identifier := range identifiers {
|
for _, identifier := range identifiers {
|
||||||
if isFQDN(identifier) {
|
if isFQDN(identifier) {
|
||||||
|
|
|
@ -29,15 +29,9 @@ type CachingResolver struct {
|
||||||
|
|
||||||
emitMetricEvents bool // disabled by Bootstrap
|
emitMetricEvents bool // disabled by Bootstrap
|
||||||
|
|
||||||
resultCache expirationcache.ExpiringCache[cacheValue]
|
resultCache expirationcache.ExpiringCache[dns.Msg]
|
||||||
prefetchingNameCache expirationcache.ExpiringCache[int]
|
|
||||||
redisClient *redis.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// cacheValue includes query answer and prefetch flag
|
redisClient *redis.Client
|
||||||
type cacheValue struct {
|
|
||||||
resultMsg *dns.Msg
|
|
||||||
prefetch bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCachingResolver creates a new resolver instance
|
// NewCachingResolver creates a new resolver instance
|
||||||
|
@ -65,73 +59,76 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
||||||
cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval)
|
options := expirationcache.Options{
|
||||||
maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount))
|
CleanupInterval: defaultCachingCleanUpInterval,
|
||||||
|
MaxSize: uint(cfg.MaxItemsCount),
|
||||||
|
OnCacheHitFn: func(key string) {
|
||||||
|
c.publishMetricsIfEnabled(evt.CachingResultCacheHit, key)
|
||||||
|
},
|
||||||
|
OnCacheMissFn: func(key string) {
|
||||||
|
c.publishMetricsIfEnabled(evt.CachingResultCacheMiss, key)
|
||||||
|
},
|
||||||
|
OnAfterPutFn: func(newSize int) {
|
||||||
|
c.publishMetricsIfEnabled(evt.CachingResultCacheChanged, newSize)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.Prefetching {
|
if cfg.Prefetching {
|
||||||
c.prefetchingNameCache = expirationcache.NewCache(
|
prefetchingOptions := expirationcache.PrefetchingOptions[dns.Msg]{
|
||||||
expirationcache.WithCleanUpInterval[int](time.Minute),
|
Options: options,
|
||||||
expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)),
|
PrefetchExpires: time.Duration(cfg.PrefetchExpires),
|
||||||
)
|
PrefetchThreshold: cfg.PrefetchThreshold,
|
||||||
|
PrefetchMaxItemsCount: cfg.PrefetchMaxItemsCount,
|
||||||
|
ReloadFn: c.reloadCacheEntry,
|
||||||
|
OnPrefetchAfterPut: func(newSize int) {
|
||||||
|
c.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, newSize)
|
||||||
|
},
|
||||||
|
OnPrefetchEntryReloaded: func(key string) {
|
||||||
|
c.publishMetricsIfEnabled(evt.CachingDomainPrefetched, key)
|
||||||
|
},
|
||||||
|
OnPrefetchCacheHit: func(key string) {
|
||||||
|
c.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, key)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
c.resultCache = expirationcache.NewCache(
|
c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions)
|
||||||
cleanupOption,
|
|
||||||
maxSizeOption,
|
|
||||||
expirationcache.WithOnExpiredFn(c.onExpired),
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption)
|
c.resultCache = expirationcache.NewCache[dns.Msg](options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Duration) {
|
||||||
|
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||||
|
logger := r.log()
|
||||||
|
|
||||||
|
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||||
|
|
||||||
|
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
|
||||||
|
response, err := r.next.Resolve(req)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
if response.Res.Rcode == dns.RcodeSuccess {
|
||||||
|
return response.Res, r.adjustTTLs(response.Res.Answer)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, 0
|
||||||
|
}
|
||||||
|
|
||||||
func setupRedisCacheSubscriber(c *CachingResolver) {
|
func setupRedisCacheSubscriber(c *CachingResolver) {
|
||||||
go func() {
|
go func() {
|
||||||
for rc := range c.redisClient.CacheChannel {
|
for rc := range c.redisClient.CacheChannel {
|
||||||
if rc != nil {
|
if rc != nil {
|
||||||
c.log().Debug("Received key from redis: ", rc.Key)
|
c.log().Debug("Received key from redis: ", rc.Key)
|
||||||
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
||||||
c.putInCache(rc.Key, rc.Response, ttl, false, false)
|
c.putInCache(rc.Key, rc.Response, ttl, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if domain was queried > threshold in the time window
|
|
||||||
func (r *CachingResolver) shouldPrefetch(cacheKey string) bool {
|
|
||||||
if r.cfg.PrefetchThreshold == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
|
|
||||||
|
|
||||||
return cnt != nil && *cnt > r.cfg.PrefetchThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *CachingResolver) onExpired(cacheKey string) (val *cacheValue, ttl time.Duration) {
|
|
||||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
|
||||||
|
|
||||||
if r.shouldPrefetch(cacheKey) {
|
|
||||||
logger := r.log()
|
|
||||||
|
|
||||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
|
||||||
|
|
||||||
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
|
|
||||||
response, err := r.next.Resolve(req)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
if response.Res.Rcode == dns.RcodeSuccess {
|
|
||||||
r.publishMetricsIfEnabled(evt.CachingDomainPrefetched, domainName)
|
|
||||||
|
|
||||||
return &cacheValue{response.Res, true}, r.adjustTTLs(response.Res.Answer)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogConfig implements `config.Configurable`.
|
// LogConfig implements `config.Configurable`.
|
||||||
func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
|
func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
|
||||||
r.cfg.LogConfig(logger)
|
r.cfg.LogConfig(logger)
|
||||||
|
@ -155,23 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
||||||
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
|
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
|
||||||
logger := logger.WithField("domain", util.Obfuscate(domain))
|
logger := logger.WithField("domain", util.Obfuscate(domain))
|
||||||
|
|
||||||
r.trackQueryDomainNameCount(domain, cacheKey, logger)
|
|
||||||
|
|
||||||
val, ttl := r.resultCache.Get(cacheKey)
|
val, ttl := r.resultCache.Get(cacheKey)
|
||||||
|
|
||||||
if val != nil {
|
if val != nil {
|
||||||
logger.Debug("domain is cached")
|
logger.Debug("domain is cached")
|
||||||
|
|
||||||
r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain)
|
resp := val.Copy()
|
||||||
|
|
||||||
if val.prefetch {
|
|
||||||
// Hit from prefetch cache
|
|
||||||
r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := val.resultMsg.Copy()
|
|
||||||
resp.SetReply(request.Req)
|
resp.SetReply(request.Req)
|
||||||
resp.Rcode = val.resultMsg.Rcode
|
resp.Rcode = val.Rcode
|
||||||
|
|
||||||
// Adjust TTL
|
// Adjust TTL
|
||||||
setTTLInCachedResponse(resp, ttl)
|
setTTLInCachedResponse(resp, ttl)
|
||||||
|
@ -183,14 +171,12 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
||||||
return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
|
return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r.publishMetricsIfEnabled(evt.CachingResultCacheMiss, domain)
|
|
||||||
|
|
||||||
logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
|
logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
|
||||||
response, err = r.next.Resolve(request)
|
response, err = r.next.Resolve(request)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
||||||
r.putInCache(cacheKey, response, cacheTTL, false, true)
|
r.putInCache(cacheKey, response, cacheTTL, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,22 +195,6 @@ func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, logger *logrus.Entry) {
|
|
||||||
if r.prefetchingNameCache != nil {
|
|
||||||
var domainCount int
|
|
||||||
if x, _ := r.prefetchingNameCache.Get(cacheKey); x != nil {
|
|
||||||
domainCount = *x
|
|
||||||
}
|
|
||||||
domainCount++
|
|
||||||
r.prefetchingNameCache.Put(cacheKey, &domainCount, r.cfg.PrefetchExpires.ToDuration())
|
|
||||||
totalCount := r.prefetchingNameCache.TotalCount()
|
|
||||||
|
|
||||||
logger.Debugf("domain '%s' was requested %d times, "+
|
|
||||||
"total cache size: %d", util.Obfuscate(domain), domainCount, totalCount)
|
|
||||||
r.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, totalCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// removes EDNS OPT records from message
|
// removes EDNS OPT records from message
|
||||||
func removeEdns0Extra(msg *dns.Msg) {
|
func removeEdns0Extra(msg *dns.Msg) {
|
||||||
if len(msg.Extra) > 0 {
|
if len(msg.Extra) > 0 {
|
||||||
|
@ -246,7 +216,7 @@ func shouldBeCached(msg *dns.Msg) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
|
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
|
||||||
prefetch, publish bool,
|
publish bool,
|
||||||
) {
|
) {
|
||||||
respCopy := response.Res.Copy()
|
respCopy := response.Res.Copy()
|
||||||
|
|
||||||
|
@ -255,16 +225,14 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
|
||||||
|
|
||||||
if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) {
|
if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) {
|
||||||
// put value into cache
|
// put value into cache
|
||||||
r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, ttl)
|
r.resultCache.Put(cacheKey, respCopy, ttl)
|
||||||
} else if response.Res.Rcode == dns.RcodeNameError {
|
} else if response.Res.Rcode == dns.RcodeNameError {
|
||||||
if r.cfg.CacheTimeNegative.IsAboveZero() {
|
if r.cfg.CacheTimeNegative.IsAboveZero() {
|
||||||
// put negative cache if result code is NXDOMAIN
|
// put negative cache if result code is NXDOMAIN
|
||||||
r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
|
r.resultCache.Put(cacheKey, respCopy, r.cfg.CacheTimeNegative.ToDuration())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount())
|
|
||||||
|
|
||||||
if publish && r.redisClient != nil {
|
if publish && r.redisClient != nil {
|
||||||
res := *respCopy
|
res := *respCopy
|
||||||
r.redisClient.PublishCache(cacheKey, &res)
|
r.redisClient.PublishCache(cacheKey, &res)
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/cache/expirationcache"
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
. "github.com/0xERR0R/blocky/evt"
|
. "github.com/0xERR0R/blocky/evt"
|
||||||
. "github.com/0xERR0R/blocky/helpertest"
|
. "github.com/0xERR0R/blocky/helpertest"
|
||||||
|
@ -81,18 +80,15 @@ var _ = Describe("CachingResolver", func() {
|
||||||
// prepare resolver, set smaller caching times for testing
|
// prepare resolver, set smaller caching times for testing
|
||||||
prefetchThreshold := 5
|
prefetchThreshold := 5
|
||||||
configureCaches(sut, &sutConfig)
|
configureCaches(sut, &sutConfig)
|
||||||
sut.resultCache = expirationcache.NewCache(
|
|
||||||
expirationcache.WithCleanUpInterval[cacheValue](100*time.Millisecond),
|
|
||||||
expirationcache.WithOnExpiredFn(sut.onExpired))
|
|
||||||
|
|
||||||
domainPrefetched := make(chan string, 1)
|
domainPrefetched := make(chan bool, 1)
|
||||||
prefetchHitDomain := make(chan string, 1)
|
prefetchHitDomain := make(chan bool, 1)
|
||||||
prefetchedCnt := make(chan int, 1)
|
prefetchedCnt := make(chan int, 1)
|
||||||
Expect(Bus().SubscribeOnce(CachingPrefetchCacheHit, func(domain string) {
|
Expect(Bus().SubscribeOnce(CachingPrefetchCacheHit, func(domain string) {
|
||||||
prefetchHitDomain <- domain
|
prefetchHitDomain <- true
|
||||||
})).Should(Succeed())
|
})).Should(Succeed())
|
||||||
Expect(Bus().SubscribeOnce(CachingDomainPrefetched, func(domain string) {
|
Expect(Bus().SubscribeOnce(CachingDomainPrefetched, func(domain string) {
|
||||||
domainPrefetched <- domain
|
domainPrefetched <- true
|
||||||
})).Should(Succeed())
|
})).Should(Succeed())
|
||||||
|
|
||||||
Expect(Bus().SubscribeOnce(CachingDomainsToPrefetchCountChanged, func(cnt int) {
|
Expect(Bus().SubscribeOnce(CachingDomainsToPrefetchCountChanged, func(cnt int) {
|
||||||
|
@ -115,7 +111,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// now is this domain prefetched
|
// now is this domain prefetched
|
||||||
Eventually(domainPrefetched, "4s").Should(Receive(Equal("example.com")))
|
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
|
||||||
|
|
||||||
// and it should hit from prefetch cache
|
// and it should hit from prefetch cache
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||||
|
@ -125,16 +121,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
BeDNSRecord("example.com.", A, "123.122.121.120"),
|
BeDNSRecord("example.com.", A, "123.122.121.120"),
|
||||||
HaveTTL(BeNumerically("<=", 2))))
|
HaveTTL(BeNumerically("<=", 2))))
|
||||||
Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com")))
|
Eventually(prefetchHitDomain, "10s").Should(Receive(Equal(true)))
|
||||||
})
|
|
||||||
When("threshold is 0", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
sutConfig.PrefetchThreshold = 0
|
|
||||||
})
|
|
||||||
|
|
||||||
It("should always prefetch", func() {
|
|
||||||
Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
When("caching with default values is enabled", func() {
|
When("caching with default values is enabled", func() {
|
||||||
|
@ -204,9 +191,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
It("should cache response and use response's TTL", func() {
|
It("should cache response and use response's TTL", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
domain := make(chan string, 1)
|
domain := make(chan bool, 1)
|
||||||
_ = Bus().SubscribeOnce(CachingResultCacheMiss, func(d string) {
|
_ = Bus().SubscribeOnce(CachingResultCacheMiss, func(d string) {
|
||||||
domain <- d
|
domain <- true
|
||||||
})
|
})
|
||||||
|
|
||||||
totalCacheCount := make(chan int, 1)
|
totalCacheCount := make(chan int, 1)
|
||||||
|
@ -223,15 +210,15 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
Expect(m.Calls).Should(HaveLen(1))
|
Expect(m.Calls).Should(HaveLen(1))
|
||||||
|
|
||||||
Expect(domain).Should(Receive(Equal("example.com")))
|
Expect(domain).Should(Receive(Equal(true)))
|
||||||
Expect(totalCacheCount).Should(Receive(Equal(1)))
|
Expect(totalCacheCount).Should(Receive(Equal(1)))
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(func(g Gomega) {
|
Eventually(func(g Gomega) {
|
||||||
domain := make(chan string, 1)
|
domain := make(chan bool, 1)
|
||||||
_ = Bus().SubscribeOnce(CachingResultCacheHit, func(d string) {
|
_ = Bus().SubscribeOnce(CachingResultCacheHit, func(d string) {
|
||||||
domain <- d
|
domain <- true
|
||||||
})
|
})
|
||||||
|
|
||||||
g.Expect(sut.Resolve(newRequest("example.com.", A))).
|
g.Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||||
|
@ -246,7 +233,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
// still one call to upstream
|
// still one call to upstream
|
||||||
g.Expect(m.Calls).Should(HaveLen(1))
|
g.Expect(m.Calls).Should(HaveLen(1))
|
||||||
|
|
||||||
g.Expect(domain).Should(Receive(Equal("example.com")))
|
g.Expect(domain).Should(Receive(Equal(true)))
|
||||||
}, "1s").Should(Succeed())
|
}, "1s").Should(Succeed())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -41,7 +41,9 @@ func NewClientNamesResolver(
|
||||||
configurable: withConfig(&cfg),
|
configurable: withConfig(&cfg),
|
||||||
typed: withType("client_names"),
|
typed: withType("client_names"),
|
||||||
|
|
||||||
cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]string](time.Hour)),
|
cache: expirationcache.NewCache[[]string](expirationcache.Options{
|
||||||
|
CleanupInterval: time.Hour,
|
||||||
|
}),
|
||||||
externalResolver: r,
|
externalResolver: r,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue