mirror of https://github.com/0xERR0R/blocky.git
chore(refactor): refactor cache implementation (#1174)
* chore(refactor): refactor cache implementation * chore: use atomic.Uint32 as prefetch names query count Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> --------- Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
This commit is contained in:
parent
d76740ea51
commit
497bd0d0fd
|
@ -19,15 +19,18 @@ type element[T any] struct {
|
|||
type ExpiringLRUCache[T any] struct {
|
||||
cleanUpInterval time.Duration
|
||||
preExpirationFn OnExpirationCallback[T]
|
||||
onCacheHit OnCacheHitCallback
|
||||
onCacheMiss OnCacheMissCallback
|
||||
onAfterPut OnAfterPutCallback
|
||||
lru *lru.Cache
|
||||
}
|
||||
|
||||
type CacheOption[T any] func(c *ExpiringLRUCache[T])
|
||||
|
||||
func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
|
||||
return func(e *ExpiringLRUCache[T]) {
|
||||
e.cleanUpInterval = d
|
||||
}
|
||||
type Options struct {
|
||||
OnCacheHitFn OnCacheHitCallback
|
||||
OnCacheMissFn OnCacheMissCallback
|
||||
OnAfterPutFn OnAfterPutCallback
|
||||
CleanupInterval time.Duration
|
||||
MaxSize uint
|
||||
}
|
||||
|
||||
// OnExpirationCallback will be called just before an element gets expired and will
|
||||
|
@ -35,33 +38,56 @@ func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
|
|||
// element in the cache or nil to remove it
|
||||
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
|
||||
|
||||
func WithOnExpiredFn[T any](fn OnExpirationCallback[T]) CacheOption[T] {
|
||||
return func(c *ExpiringLRUCache[T]) {
|
||||
c.preExpirationFn = fn
|
||||
}
|
||||
// OnCacheHitCallback will be called on cache get if entry was found
|
||||
type OnCacheHitCallback func(key string)
|
||||
|
||||
// OnCacheMissCallback will be called on cache get and entry was not found
|
||||
type OnCacheMissCallback func(key string)
|
||||
|
||||
// OnAfterPutCallback will be called after put, receives new element count as parameter
|
||||
type OnAfterPutCallback func(newSize int)
|
||||
|
||||
func NewCache[T any](options Options) *ExpiringLRUCache[T] {
|
||||
return NewCacheWithOnExpired[T](options, nil)
|
||||
}
|
||||
|
||||
func WithMaxSize[T any](size uint) CacheOption[T] {
|
||||
return func(c *ExpiringLRUCache[T]) {
|
||||
if size > 0 {
|
||||
l, _ := lru.New(int(size))
|
||||
c.lru = l
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewCache[T any](options ...CacheOption[T]) *ExpiringLRUCache[T] {
|
||||
func NewCacheWithOnExpired[T any](options Options,
|
||||
onExpirationFn OnExpirationCallback[T],
|
||||
) *ExpiringLRUCache[T] {
|
||||
l, _ := lru.New(defaultSize)
|
||||
c := &ExpiringLRUCache[T]{
|
||||
cleanUpInterval: defaultCleanUpInterval,
|
||||
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
|
||||
return nil, 0
|
||||
},
|
||||
lru: l,
|
||||
onCacheHit: func(key string) {},
|
||||
onCacheMiss: func(key string) {},
|
||||
lru: l,
|
||||
}
|
||||
|
||||
for _, opt := range options {
|
||||
opt(c)
|
||||
if options.CleanupInterval > 0 {
|
||||
c.cleanUpInterval = options.CleanupInterval
|
||||
}
|
||||
|
||||
if options.MaxSize > 0 {
|
||||
l, _ := lru.New(int(options.MaxSize))
|
||||
c.lru = l
|
||||
}
|
||||
|
||||
if options.OnAfterPutFn != nil {
|
||||
c.onAfterPut = options.OnAfterPutFn
|
||||
}
|
||||
|
||||
if options.OnCacheHitFn != nil {
|
||||
c.onCacheHit = options.OnCacheHitFn
|
||||
}
|
||||
|
||||
if options.OnCacheMissFn != nil {
|
||||
c.onCacheMiss = options.OnCacheMissFn
|
||||
}
|
||||
|
||||
if onExpirationFn != nil {
|
||||
c.preExpirationFn = onExpirationFn
|
||||
}
|
||||
|
||||
go periodicCleanup(c)
|
||||
|
@ -122,15 +148,23 @@ func (e *ExpiringLRUCache[T]) Put(key string, val *T, ttl time.Duration) {
|
|||
val: val,
|
||||
expiresEpochMs: expiresEpochMs,
|
||||
})
|
||||
|
||||
if e.onAfterPut != nil {
|
||||
e.onAfterPut(e.lru.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ExpiringLRUCache[T]) Get(key string) (val *T, ttl time.Duration) {
|
||||
el, found := e.lru.Get(key)
|
||||
|
||||
if found {
|
||||
e.onCacheHit(key)
|
||||
|
||||
return el.(*element[T]).val, calculateRemainTTL(el.(*element[T]).expiresEpochMs)
|
||||
}
|
||||
|
||||
e.onCacheMiss(key)
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
|
|
|
@ -11,11 +11,11 @@ var _ = Describe("Expiration cache", func() {
|
|||
Describe("Basic operations", func() {
|
||||
When("string cache was created", func() {
|
||||
It("Initial cache should be empty", func() {
|
||||
cache := NewCache[string]()
|
||||
cache := NewCache[string](Options{})
|
||||
Expect(cache.TotalCount()).Should(Equal(0))
|
||||
})
|
||||
It("Initial cache should not contain any elements", func() {
|
||||
cache := NewCache[string]()
|
||||
cache := NewCache[string](Options{})
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(BeNil())
|
||||
Expect(expiration).Should(Equal(time.Duration(0)))
|
||||
|
@ -23,7 +23,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
When("Put new value with positive TTL", func() {
|
||||
It("Should return the value before element expires", func() {
|
||||
cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
|
||||
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
val, expiration := cache.Get("key1")
|
||||
|
@ -33,7 +33,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
Expect(cache.TotalCount()).Should(Equal(1))
|
||||
})
|
||||
It("Should return nil after expiration", func() {
|
||||
cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
|
||||
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
|
||||
|
@ -52,7 +52,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
When("Put new value without expiration", func() {
|
||||
It("Should not cache the value", func() {
|
||||
cache := NewCache(WithCleanUpInterval[string](50 * time.Millisecond))
|
||||
cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond})
|
||||
v := "x"
|
||||
cache.Put("key1", &v, 0)
|
||||
val, expiration := cache.Get("key1")
|
||||
|
@ -63,7 +63,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
When("Put updated value", func() {
|
||||
It("Should return updated value", func() {
|
||||
cache := NewCache[string]()
|
||||
cache := NewCache[string](Options{})
|
||||
v1 := "v1"
|
||||
v2 := "v2"
|
||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||
|
@ -79,7 +79,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
When("Purging after usage", func() {
|
||||
It("Should be empty after purge", func() {
|
||||
cache := NewCache[string]()
|
||||
cache := NewCache[string](Options{})
|
||||
v1 := "y"
|
||||
cache.Put("key1", &v1, time.Second)
|
||||
|
||||
|
@ -91,15 +91,62 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
Describe("Hook functions", func() {
|
||||
When("Hook functions are defined", func() {
|
||||
It("should call each hook function", func() {
|
||||
onCacheHitChannel := make(chan string, 10)
|
||||
onCacheMissChannel := make(chan string, 10)
|
||||
onAfterPutChannel := make(chan int, 10)
|
||||
cache := NewCache[string](Options{
|
||||
OnCacheHitFn: func(key string) {
|
||||
onCacheHitChannel <- key
|
||||
},
|
||||
OnCacheMissFn: func(key string) {
|
||||
onCacheMissChannel <- key
|
||||
},
|
||||
OnAfterPutFn: func(newSize int) {
|
||||
onAfterPutChannel <- newSize
|
||||
},
|
||||
})
|
||||
|
||||
By("Get non existing value", func() {
|
||||
val, _ := cache.Get("notExists")
|
||||
Expect(val).Should(BeNil())
|
||||
|
||||
Expect(onCacheMissChannel).Should(Receive(Equal("notExists")))
|
||||
Expect(onCacheHitChannel).Should(Not(Receive()))
|
||||
Expect(onAfterPutChannel).Should(Not(Receive()))
|
||||
})
|
||||
|
||||
By("Put new cache entry", func() {
|
||||
v1 := "v1"
|
||||
cache.Put("key1", &v1, time.Second)
|
||||
Expect(onCacheMissChannel).Should(Not(Receive()))
|
||||
Expect(onCacheMissChannel).Should(Not(Receive()))
|
||||
Expect(onAfterPutChannel).Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
By("Get existing value", func() {
|
||||
val, _ := cache.Get("key1")
|
||||
Expect(val).Should(HaveValue(Equal("v1")))
|
||||
|
||||
Expect(onCacheMissChannel).Should(Not(Receive()))
|
||||
Expect(onCacheHitChannel).Should(Receive(Equal("key1")))
|
||||
Expect(onAfterPutChannel).Should(Not(Receive()))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
Describe("preExpiration function", func() {
|
||||
When(" function is defined", func() {
|
||||
When("function is defined", func() {
|
||||
It("should update the value and TTL if function returns values", func() {
|
||||
fn := func(key string) (val *string, ttl time.Duration) {
|
||||
v2 := "v2"
|
||||
|
||||
return &v2, time.Second
|
||||
}
|
||||
cache := NewCache(WithOnExpiredFn(fn))
|
||||
|
||||
cache := NewCacheWithOnExpired[string](Options{}, fn)
|
||||
v1 := "v1"
|
||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||
|
||||
|
@ -118,7 +165,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
|
||||
return &v2, time.Second
|
||||
}
|
||||
cache := NewCache(WithOnExpiredFn(fn))
|
||||
cache := NewCacheWithOnExpired[string](Options{}, fn)
|
||||
v1 := "somval"
|
||||
cache.Put("key1", &v1, time.Millisecond)
|
||||
|
||||
|
@ -139,7 +186,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
fn := func(key string) (val *string, ttl time.Duration) {
|
||||
return nil, 0
|
||||
}
|
||||
cache := NewCache(WithCleanUpInterval[string](100*time.Millisecond), WithOnExpiredFn(fn))
|
||||
cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn)
|
||||
v1 := "z"
|
||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||
|
||||
|
@ -152,7 +199,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
Describe("LRU behaviour", func() {
|
||||
When("Defined max size is reached", func() {
|
||||
It("should remove old elements", func() {
|
||||
cache := NewCache(WithMaxSize[string](3))
|
||||
cache := NewCache[string](Options{MaxSize: 3})
|
||||
|
||||
v1 := "val1"
|
||||
v2 := "val2"
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
package expirationcache
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PrefetchingExpiringLRUCache[T any] struct {
|
||||
cache ExpiringCache[cacheValue[T]]
|
||||
prefetchingNameCache ExpiringCache[atomic.Uint32]
|
||||
reloadFn ReloadEntryFn[T]
|
||||
prefetchThreshold int
|
||||
prefetchExpires time.Duration
|
||||
onPrefetchEntryReloaded OnEntryReloadedCallback
|
||||
onPrefetchCacheHit OnCacheHitCallback
|
||||
}
|
||||
|
||||
type cacheValue[T any] struct {
|
||||
element *T
|
||||
prefetch bool
|
||||
}
|
||||
|
||||
// OnEntryReloadedCallback will be called if a prefetched entry is reloaded
|
||||
type OnEntryReloadedCallback func(key string)
|
||||
|
||||
// ReloadEntryFn reloads a prefetched entry by key
|
||||
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
|
||||
|
||||
type PrefetchingOptions[T any] struct {
|
||||
Options
|
||||
ReloadFn func(cacheKey string) (*T, time.Duration)
|
||||
PrefetchThreshold int
|
||||
PrefetchExpires time.Duration
|
||||
PrefetchMaxItemsCount int
|
||||
OnPrefetchAfterPut OnAfterPutCallback
|
||||
OnPrefetchEntryReloaded OnEntryReloadedCallback
|
||||
OnPrefetchCacheHit OnCacheHitCallback
|
||||
}
|
||||
|
||||
type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]])
|
||||
|
||||
func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
|
||||
pc := &PrefetchingExpiringLRUCache[T]{
|
||||
prefetchingNameCache: NewCache[atomic.Uint32](Options{
|
||||
CleanupInterval: time.Minute,
|
||||
MaxSize: uint(options.PrefetchMaxItemsCount),
|
||||
OnAfterPutFn: options.OnPrefetchAfterPut,
|
||||
}),
|
||||
prefetchExpires: options.PrefetchExpires,
|
||||
prefetchThreshold: options.PrefetchThreshold,
|
||||
reloadFn: options.ReloadFn,
|
||||
onPrefetchEntryReloaded: options.OnPrefetchEntryReloaded,
|
||||
onPrefetchCacheHit: options.OnPrefetchCacheHit,
|
||||
}
|
||||
|
||||
pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired)
|
||||
|
||||
return pc
|
||||
}
|
||||
|
||||
// check if a cache entry should be prefetched: was queried > threshold in the time window
|
||||
func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
|
||||
if e.prefetchThreshold == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
cnt, _ := e.prefetchingNameCache.Get(cacheKey)
|
||||
|
||||
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
|
||||
}
|
||||
|
||||
func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
|
||||
if e.shouldPrefetch(cacheKey) {
|
||||
loadedVal, ttl := e.reloadFn(cacheKey)
|
||||
if loadedVal != nil {
|
||||
if e.onPrefetchEntryReloaded != nil {
|
||||
e.onPrefetchEntryReloaded(cacheKey)
|
||||
}
|
||||
|
||||
return &cacheValue[T]{loadedVal, true}, ttl
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
func (e *PrefetchingExpiringLRUCache[T]) trackCacheKeyQueryCount(cacheKey string) {
|
||||
var x *atomic.Uint32
|
||||
if x, _ = e.prefetchingNameCache.Get(cacheKey); x == nil {
|
||||
x = &atomic.Uint32{}
|
||||
}
|
||||
|
||||
x.Add(1)
|
||||
e.prefetchingNameCache.Put(cacheKey, x, e.prefetchExpires)
|
||||
}
|
||||
|
||||
func (e *PrefetchingExpiringLRUCache[T]) Put(key string, val *T, expiration time.Duration) {
|
||||
e.cache.Put(key, &cacheValue[T]{element: val, prefetch: false}, expiration)
|
||||
}
|
||||
|
||||
// Get returns the value of cached entry with remained TTL. If entry is not cached, returns nil
|
||||
func (e *PrefetchingExpiringLRUCache[T]) Get(key string) (val *T, expiration time.Duration) {
|
||||
e.trackCacheKeyQueryCount(key)
|
||||
|
||||
res, exp := e.cache.Get(key)
|
||||
|
||||
if res == nil {
|
||||
return nil, exp
|
||||
}
|
||||
|
||||
if e.onPrefetchCacheHit != nil && res.prefetch {
|
||||
// Hit from prefetch cache
|
||||
e.onPrefetchCacheHit(key)
|
||||
}
|
||||
|
||||
return res.element, exp
|
||||
}
|
||||
|
||||
// TotalCount returns the total count of valid (not expired) elements
|
||||
func (e *PrefetchingExpiringLRUCache[T]) TotalCount() int {
|
||||
return e.cache.TotalCount()
|
||||
}
|
||||
|
||||
// Clear removes all cache entries
|
||||
func (e *PrefetchingExpiringLRUCache[T]) Clear() {
|
||||
e.cache.Clear()
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
package expirationcache
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Prefetching expiration cache", func() {
|
||||
Describe("Basic operations", func() {
|
||||
When("string cache was created", func() {
|
||||
It("Initial cache should be empty", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||
Expect(cache.TotalCount()).Should(Equal(0))
|
||||
})
|
||||
It("Initial cache should not contain any elements", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(BeNil())
|
||||
Expect(expiration).Should(Equal(time.Duration(0)))
|
||||
})
|
||||
|
||||
It("Should work as cache (basic operations)", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(HaveValue(Equal("v1")))
|
||||
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||
|
||||
Expect(cache.TotalCount()).Should(Equal(1))
|
||||
|
||||
cache.Clear()
|
||||
|
||||
Expect(cache.TotalCount()).Should(Equal(0))
|
||||
})
|
||||
})
|
||||
Context("Prefetching", func() {
|
||||
It("Should prefetch element", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
PrefetchThreshold: 2,
|
||||
PrefetchExpires: 100 * time.Millisecond,
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
},
|
||||
})
|
||||
By("put a value and get it again", func() {
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(HaveValue(Equal("v1")))
|
||||
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||
|
||||
Expect(cache.TotalCount()).Should(Equal(1))
|
||||
})
|
||||
By("Get value twice to trigger prefetching", func() {
|
||||
val, _ := cache.Get("key1")
|
||||
Expect(val).Should(HaveValue(Equal("v1")))
|
||||
|
||||
Eventually(func(g Gomega) {
|
||||
val, _ := cache.Get("key1")
|
||||
g.Expect(val).Should(HaveValue(Equal("v2")))
|
||||
}).Should(Succeed())
|
||||
})
|
||||
})
|
||||
It("Should not prefetch element", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
PrefetchThreshold: 2,
|
||||
PrefetchExpires: 100 * time.Millisecond,
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
},
|
||||
})
|
||||
By("put a value and get it again", func() {
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(HaveValue(Equal("v1")))
|
||||
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||
|
||||
Expect(cache.TotalCount()).Should(Equal(1))
|
||||
})
|
||||
By("Wait for expiration -> the entry should not be prefetched, threshold was not reached", func() {
|
||||
Eventually(func(g Gomega) {
|
||||
val, _ := cache.Get("key1")
|
||||
g.Expect(val).Should(BeNil())
|
||||
}, "5s", "500ms").Should(Succeed())
|
||||
})
|
||||
})
|
||||
It("With default config (threshold = 0) should always prefetch", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
},
|
||||
})
|
||||
By("put a value and get it again", func() {
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(HaveValue(Equal("v1")))
|
||||
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
|
||||
})
|
||||
By("Should return new prefetched value after expiration", func() {
|
||||
Eventually(func(g Gomega) {
|
||||
val, _ := cache.Get("key1")
|
||||
g.Expect(val).Should(HaveValue(Equal("v2")))
|
||||
}, "5s").Should(Succeed())
|
||||
})
|
||||
})
|
||||
It("Should execute hook functions", func() {
|
||||
onPrefetchAfterPutChannel := make(chan int, 10)
|
||||
onPrefetchEntryReloaded := make(chan string, 10)
|
||||
onnPrefetchCacheHit := make(chan string, 10)
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
PrefetchThreshold: 2,
|
||||
PrefetchExpires: 100 * time.Millisecond,
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
},
|
||||
OnPrefetchAfterPut: func(newSize int) { onPrefetchAfterPutChannel <- newSize },
|
||||
OnPrefetchEntryReloaded: func(key string) { onPrefetchEntryReloaded <- key },
|
||||
OnPrefetchCacheHit: func(key string) { onnPrefetchCacheHit <- key },
|
||||
})
|
||||
By("put a value", func() {
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
Expect(onPrefetchAfterPutChannel).Should(Not(Receive()))
|
||||
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
|
||||
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
|
||||
})
|
||||
By("get a value 3 times to trigger prefetching", func() {
|
||||
// first get
|
||||
cache.Get("key1")
|
||||
|
||||
Expect(onPrefetchAfterPutChannel).Should(Receive(Equal(1)))
|
||||
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
|
||||
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
|
||||
|
||||
// secont get
|
||||
val, _ := cache.Get("key1")
|
||||
Expect(val).Should(HaveValue(Equal("v1")))
|
||||
|
||||
// third get -> this should trigger prefetching after expiration
|
||||
cache.Get("key1")
|
||||
|
||||
// reload was executed
|
||||
Eventually(onPrefetchEntryReloaded).Should(Receive(Equal("key1")))
|
||||
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
|
||||
// has new value
|
||||
Eventually(func(g Gomega) {
|
||||
val, _ := cache.Get("key1")
|
||||
g.Expect(val).Should(HaveValue(Equal("v2")))
|
||||
}, "5s").Should(Succeed())
|
||||
|
||||
// prefetch hit
|
||||
Eventually(onnPrefetchCacheHit).Should(Receive(Equal("key1")))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -20,10 +20,10 @@ const (
|
|||
// CachingPrefetchCacheHit fires if a query result was found in the prefetch cache, Parameter: domain name
|
||||
CachingPrefetchCacheHit = "caching:prefetchHit"
|
||||
|
||||
// CachingResultCacheHit fires, if a query result was found in the cache, Parameter: domain name
|
||||
// CachingResultCacheHit fires, if a query result was found in the cache
|
||||
CachingResultCacheHit = "caching:cacheHit"
|
||||
|
||||
// CachingResultCacheMiss fires, if a query result was not found in the cache, Parameter: domain name
|
||||
// CachingResultCacheMiss fires, if a query result was not found in the cache
|
||||
CachingResultCacheMiss = "caching:cacheMiss"
|
||||
|
||||
// CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count
|
||||
|
|
|
@ -603,10 +603,11 @@ func (r *BlockingResolver) initFQDNIPCache() {
|
|||
identifiers = append(identifiers, identifier)
|
||||
}
|
||||
|
||||
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]net.IP](defaultBlockingCleanUpInterval),
|
||||
expirationcache.WithOnExpiredFn(func(key string) (val *[]net.IP, ttl time.Duration) {
|
||||
return r.queryForFQIdentifierIPs(key)
|
||||
}))
|
||||
r.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{
|
||||
CleanupInterval: defaultBlockingCleanUpInterval,
|
||||
}, func(key string) (val *[]net.IP, ttl time.Duration) {
|
||||
return r.queryForFQIdentifierIPs(key)
|
||||
})
|
||||
|
||||
for _, identifier := range identifiers {
|
||||
if isFQDN(identifier) {
|
||||
|
|
|
@ -29,15 +29,9 @@ type CachingResolver struct {
|
|||
|
||||
emitMetricEvents bool // disabled by Bootstrap
|
||||
|
||||
resultCache expirationcache.ExpiringCache[cacheValue]
|
||||
prefetchingNameCache expirationcache.ExpiringCache[int]
|
||||
redisClient *redis.Client
|
||||
}
|
||||
resultCache expirationcache.ExpiringCache[dns.Msg]
|
||||
|
||||
// cacheValue includes query answer and prefetch flag
|
||||
type cacheValue struct {
|
||||
resultMsg *dns.Msg
|
||||
prefetch bool
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
||||
// NewCachingResolver creates a new resolver instance
|
||||
|
@ -65,73 +59,76 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
|
|||
}
|
||||
|
||||
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
||||
cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval)
|
||||
maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount))
|
||||
options := expirationcache.Options{
|
||||
CleanupInterval: defaultCachingCleanUpInterval,
|
||||
MaxSize: uint(cfg.MaxItemsCount),
|
||||
OnCacheHitFn: func(key string) {
|
||||
c.publishMetricsIfEnabled(evt.CachingResultCacheHit, key)
|
||||
},
|
||||
OnCacheMissFn: func(key string) {
|
||||
c.publishMetricsIfEnabled(evt.CachingResultCacheMiss, key)
|
||||
},
|
||||
OnAfterPutFn: func(newSize int) {
|
||||
c.publishMetricsIfEnabled(evt.CachingResultCacheChanged, newSize)
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.Prefetching {
|
||||
c.prefetchingNameCache = expirationcache.NewCache(
|
||||
expirationcache.WithCleanUpInterval[int](time.Minute),
|
||||
expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)),
|
||||
)
|
||||
prefetchingOptions := expirationcache.PrefetchingOptions[dns.Msg]{
|
||||
Options: options,
|
||||
PrefetchExpires: time.Duration(cfg.PrefetchExpires),
|
||||
PrefetchThreshold: cfg.PrefetchThreshold,
|
||||
PrefetchMaxItemsCount: cfg.PrefetchMaxItemsCount,
|
||||
ReloadFn: c.reloadCacheEntry,
|
||||
OnPrefetchAfterPut: func(newSize int) {
|
||||
c.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, newSize)
|
||||
},
|
||||
OnPrefetchEntryReloaded: func(key string) {
|
||||
c.publishMetricsIfEnabled(evt.CachingDomainPrefetched, key)
|
||||
},
|
||||
OnPrefetchCacheHit: func(key string) {
|
||||
c.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, key)
|
||||
},
|
||||
}
|
||||
|
||||
c.resultCache = expirationcache.NewCache(
|
||||
cleanupOption,
|
||||
maxSizeOption,
|
||||
expirationcache.WithOnExpiredFn(c.onExpired),
|
||||
)
|
||||
c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions)
|
||||
} else {
|
||||
c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption)
|
||||
c.resultCache = expirationcache.NewCache[dns.Msg](options)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Duration) {
|
||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||
logger := r.log()
|
||||
|
||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||
|
||||
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
|
||||
response, err := r.next.Resolve(req)
|
||||
|
||||
if err == nil {
|
||||
if response.Res.Rcode == dns.RcodeSuccess {
|
||||
return response.Res, r.adjustTTLs(response.Res.Answer)
|
||||
}
|
||||
} else {
|
||||
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
func setupRedisCacheSubscriber(c *CachingResolver) {
|
||||
go func() {
|
||||
for rc := range c.redisClient.CacheChannel {
|
||||
if rc != nil {
|
||||
c.log().Debug("Received key from redis: ", rc.Key)
|
||||
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
||||
c.putInCache(rc.Key, rc.Response, ttl, false, false)
|
||||
c.putInCache(rc.Key, rc.Response, ttl, false)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// check if domain was queried > threshold in the time window
|
||||
func (r *CachingResolver) shouldPrefetch(cacheKey string) bool {
|
||||
if r.cfg.PrefetchThreshold == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
|
||||
|
||||
return cnt != nil && *cnt > r.cfg.PrefetchThreshold
|
||||
}
|
||||
|
||||
func (r *CachingResolver) onExpired(cacheKey string) (val *cacheValue, ttl time.Duration) {
|
||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||
|
||||
if r.shouldPrefetch(cacheKey) {
|
||||
logger := r.log()
|
||||
|
||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||
|
||||
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
|
||||
response, err := r.next.Resolve(req)
|
||||
|
||||
if err == nil {
|
||||
if response.Res.Rcode == dns.RcodeSuccess {
|
||||
r.publishMetricsIfEnabled(evt.CachingDomainPrefetched, domainName)
|
||||
|
||||
return &cacheValue{response.Res, true}, r.adjustTTLs(response.Res.Answer)
|
||||
}
|
||||
} else {
|
||||
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
|
||||
r.cfg.LogConfig(logger)
|
||||
|
@ -155,23 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
|||
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
|
||||
logger := logger.WithField("domain", util.Obfuscate(domain))
|
||||
|
||||
r.trackQueryDomainNameCount(domain, cacheKey, logger)
|
||||
|
||||
val, ttl := r.resultCache.Get(cacheKey)
|
||||
|
||||
if val != nil {
|
||||
logger.Debug("domain is cached")
|
||||
|
||||
r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain)
|
||||
|
||||
if val.prefetch {
|
||||
// Hit from prefetch cache
|
||||
r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
|
||||
}
|
||||
|
||||
resp := val.resultMsg.Copy()
|
||||
resp := val.Copy()
|
||||
resp.SetReply(request.Req)
|
||||
resp.Rcode = val.resultMsg.Rcode
|
||||
resp.Rcode = val.Rcode
|
||||
|
||||
// Adjust TTL
|
||||
setTTLInCachedResponse(resp, ttl)
|
||||
|
@ -183,14 +171,12 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
|||
return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
|
||||
}
|
||||
|
||||
r.publishMetricsIfEnabled(evt.CachingResultCacheMiss, domain)
|
||||
|
||||
logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
|
||||
response, err = r.next.Resolve(request)
|
||||
|
||||
if err == nil {
|
||||
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
||||
r.putInCache(cacheKey, response, cacheTTL, false, true)
|
||||
r.putInCache(cacheKey, response, cacheTTL, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -209,22 +195,6 @@ func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, logger *logrus.Entry) {
|
||||
if r.prefetchingNameCache != nil {
|
||||
var domainCount int
|
||||
if x, _ := r.prefetchingNameCache.Get(cacheKey); x != nil {
|
||||
domainCount = *x
|
||||
}
|
||||
domainCount++
|
||||
r.prefetchingNameCache.Put(cacheKey, &domainCount, r.cfg.PrefetchExpires.ToDuration())
|
||||
totalCount := r.prefetchingNameCache.TotalCount()
|
||||
|
||||
logger.Debugf("domain '%s' was requested %d times, "+
|
||||
"total cache size: %d", util.Obfuscate(domain), domainCount, totalCount)
|
||||
r.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, totalCount)
|
||||
}
|
||||
}
|
||||
|
||||
// removes EDNS OPT records from message
|
||||
func removeEdns0Extra(msg *dns.Msg) {
|
||||
if len(msg.Extra) > 0 {
|
||||
|
@ -246,7 +216,7 @@ func shouldBeCached(msg *dns.Msg) bool {
|
|||
}
|
||||
|
||||
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
|
||||
prefetch, publish bool,
|
||||
publish bool,
|
||||
) {
|
||||
respCopy := response.Res.Copy()
|
||||
|
||||
|
@ -255,16 +225,14 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
|
|||
|
||||
if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) {
|
||||
// put value into cache
|
||||
r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, ttl)
|
||||
r.resultCache.Put(cacheKey, respCopy, ttl)
|
||||
} else if response.Res.Rcode == dns.RcodeNameError {
|
||||
if r.cfg.CacheTimeNegative.IsAboveZero() {
|
||||
// put negative cache if result code is NXDOMAIN
|
||||
r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
|
||||
r.resultCache.Put(cacheKey, respCopy, r.cfg.CacheTimeNegative.ToDuration())
|
||||
}
|
||||
}
|
||||
|
||||
r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount())
|
||||
|
||||
if publish && r.redisClient != nil {
|
||||
res := *respCopy
|
||||
r.redisClient.PublishCache(cacheKey, &res)
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/cache/expirationcache"
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/evt"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
|
@ -81,18 +80,15 @@ var _ = Describe("CachingResolver", func() {
|
|||
// prepare resolver, set smaller caching times for testing
|
||||
prefetchThreshold := 5
|
||||
configureCaches(sut, &sutConfig)
|
||||
sut.resultCache = expirationcache.NewCache(
|
||||
expirationcache.WithCleanUpInterval[cacheValue](100*time.Millisecond),
|
||||
expirationcache.WithOnExpiredFn(sut.onExpired))
|
||||
|
||||
domainPrefetched := make(chan string, 1)
|
||||
prefetchHitDomain := make(chan string, 1)
|
||||
domainPrefetched := make(chan bool, 1)
|
||||
prefetchHitDomain := make(chan bool, 1)
|
||||
prefetchedCnt := make(chan int, 1)
|
||||
Expect(Bus().SubscribeOnce(CachingPrefetchCacheHit, func(domain string) {
|
||||
prefetchHitDomain <- domain
|
||||
prefetchHitDomain <- true
|
||||
})).Should(Succeed())
|
||||
Expect(Bus().SubscribeOnce(CachingDomainPrefetched, func(domain string) {
|
||||
domainPrefetched <- domain
|
||||
domainPrefetched <- true
|
||||
})).Should(Succeed())
|
||||
|
||||
Expect(Bus().SubscribeOnce(CachingDomainsToPrefetchCountChanged, func(cnt int) {
|
||||
|
@ -115,7 +111,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
}
|
||||
|
||||
// now is this domain prefetched
|
||||
Eventually(domainPrefetched, "4s").Should(Receive(Equal("example.com")))
|
||||
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
|
||||
|
||||
// and it should hit from prefetch cache
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
|
@ -125,16 +121,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
BeDNSRecord("example.com.", A, "123.122.121.120"),
|
||||
HaveTTL(BeNumerically("<=", 2))))
|
||||
Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com")))
|
||||
})
|
||||
When("threshold is 0", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.PrefetchThreshold = 0
|
||||
})
|
||||
|
||||
It("should always prefetch", func() {
|
||||
Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue())
|
||||
})
|
||||
Eventually(prefetchHitDomain, "10s").Should(Receive(Equal(true)))
|
||||
})
|
||||
})
|
||||
When("caching with default values is enabled", func() {
|
||||
|
@ -204,9 +191,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
It("should cache response and use response's TTL", func() {
|
||||
By("first request", func() {
|
||||
domain := make(chan string, 1)
|
||||
domain := make(chan bool, 1)
|
||||
_ = Bus().SubscribeOnce(CachingResultCacheMiss, func(d string) {
|
||||
domain <- d
|
||||
domain <- true
|
||||
})
|
||||
|
||||
totalCacheCount := make(chan int, 1)
|
||||
|
@ -223,15 +210,15 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
Expect(m.Calls).Should(HaveLen(1))
|
||||
|
||||
Expect(domain).Should(Receive(Equal("example.com")))
|
||||
Expect(domain).Should(Receive(Equal(true)))
|
||||
Expect(totalCacheCount).Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(func(g Gomega) {
|
||||
domain := make(chan string, 1)
|
||||
domain := make(chan bool, 1)
|
||||
_ = Bus().SubscribeOnce(CachingResultCacheHit, func(d string) {
|
||||
domain <- d
|
||||
domain <- true
|
||||
})
|
||||
|
||||
g.Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
|
@ -246,7 +233,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
// still one call to upstream
|
||||
g.Expect(m.Calls).Should(HaveLen(1))
|
||||
|
||||
g.Expect(domain).Should(Receive(Equal("example.com")))
|
||||
g.Expect(domain).Should(Receive(Equal(true)))
|
||||
}, "1s").Should(Succeed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -41,7 +41,9 @@ func NewClientNamesResolver(
|
|||
configurable: withConfig(&cfg),
|
||||
typed: withType("client_names"),
|
||||
|
||||
cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]string](time.Hour)),
|
||||
cache: expirationcache.NewCache[[]string](expirationcache.Options{
|
||||
CleanupInterval: time.Hour,
|
||||
}),
|
||||
externalResolver: r,
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue