From 497bd0d0fd1a2afb58cfd36a3c37c8e5aca9a467 Mon Sep 17 00:00:00 2001 From: Dimitri Herzog Date: Sat, 30 Sep 2023 17:14:59 +0200 Subject: [PATCH] chore(refactor): refactor cache implementation (#1174) * chore(refactor): refactor cache implementation * chore: use atomic.Uint32 as prefetch names query count Co-authored-by: ThinkChaos --------- Co-authored-by: ThinkChaos --- cache/expirationcache/expiration_cache.go | 80 +++++--- .../expirationcache/expiration_cache_test.go | 71 +++++-- cache/expirationcache/prefetching_cache.go | 127 ++++++++++++ .../expirationcache/prefetching_cache_test.go | 183 ++++++++++++++++++ evt/events.go | 4 +- resolver/blocking_resolver.go | 9 +- resolver/caching_resolver.go | 152 ++++++--------- resolver/caching_resolver_test.go | 37 ++-- resolver/client_names_resolver.go | 4 +- 9 files changed, 508 insertions(+), 159 deletions(-) create mode 100644 cache/expirationcache/prefetching_cache.go create mode 100644 cache/expirationcache/prefetching_cache_test.go diff --git a/cache/expirationcache/expiration_cache.go b/cache/expirationcache/expiration_cache.go index 055c51d1..01b13449 100644 --- a/cache/expirationcache/expiration_cache.go +++ b/cache/expirationcache/expiration_cache.go @@ -19,15 +19,18 @@ type element[T any] struct { type ExpiringLRUCache[T any] struct { cleanUpInterval time.Duration preExpirationFn OnExpirationCallback[T] + onCacheHit OnCacheHitCallback + onCacheMiss OnCacheMissCallback + onAfterPut OnAfterPutCallback lru *lru.Cache } -type CacheOption[T any] func(c *ExpiringLRUCache[T]) - -func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] { - return func(e *ExpiringLRUCache[T]) { - e.cleanUpInterval = d - } +type Options struct { + OnCacheHitFn OnCacheHitCallback + OnCacheMissFn OnCacheMissCallback + OnAfterPutFn OnAfterPutCallback + CleanupInterval time.Duration + MaxSize uint } // OnExpirationCallback will be called just before an element gets expired and will @@ -35,33 +38,56 @@ func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] { // element in the cache or nil to remove it type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration) -func WithOnExpiredFn[T any](fn OnExpirationCallback[T]) CacheOption[T] { - return func(c *ExpiringLRUCache[T]) { - c.preExpirationFn = fn - } +// OnCacheHitCallback will be called on cache get if entry was found +type OnCacheHitCallback func(key string) + +// OnCacheMissCallback will be called on cache get and entry was not found +type OnCacheMissCallback func(key string) + +// OnAfterPutCallback will be called after put, receives new element count as parameter +type OnAfterPutCallback func(newSize int) + +func NewCache[T any](options Options) *ExpiringLRUCache[T] { + return NewCacheWithOnExpired[T](options, nil) } -func WithMaxSize[T any](size uint) CacheOption[T] { - return func(c *ExpiringLRUCache[T]) { - if size > 0 { - l, _ := lru.New(int(size)) - c.lru = l - } - } -} - -func NewCache[T any](options ...CacheOption[T]) *ExpiringLRUCache[T] { +func NewCacheWithOnExpired[T any](options Options, + onExpirationFn OnExpirationCallback[T], +) *ExpiringLRUCache[T] { l, _ := lru.New(defaultSize) c := &ExpiringLRUCache[T]{ cleanUpInterval: defaultCleanUpInterval, preExpirationFn: func(key string) (val *T, ttl time.Duration) { return nil, 0 }, - lru: l, + onCacheHit: func(key string) {}, + onCacheMiss: func(key string) {}, + lru: l, } - for _, opt := range options { - opt(c) + if options.CleanupInterval > 0 { + c.cleanUpInterval = options.CleanupInterval + } + + if options.MaxSize > 0 { + l, _ := lru.New(int(options.MaxSize)) + c.lru = l + } + + if options.OnAfterPutFn != nil { + c.onAfterPut = options.OnAfterPutFn + } + + if options.OnCacheHitFn != nil { + c.onCacheHit = options.OnCacheHitFn + } + + if options.OnCacheMissFn != nil { + c.onCacheMiss = options.OnCacheMissFn + } + + if onExpirationFn != nil { + c.preExpirationFn = onExpirationFn } go periodicCleanup(c) @@ -122,15 +148,23 @@ func (e *ExpiringLRUCache[T]) Put(key string, val *T, ttl time.Duration) { val: val, expiresEpochMs: expiresEpochMs, }) + + if e.onAfterPut != nil { + e.onAfterPut(e.lru.Len()) + } } func (e *ExpiringLRUCache[T]) Get(key string) (val *T, ttl time.Duration) { el, found := e.lru.Get(key) if found { + e.onCacheHit(key) + return el.(*element[T]).val, calculateRemainTTL(el.(*element[T]).expiresEpochMs) } + e.onCacheMiss(key) + return nil, 0 } diff --git a/cache/expirationcache/expiration_cache_test.go b/cache/expirationcache/expiration_cache_test.go index a1f92ab5..dad89e62 100644 --- a/cache/expirationcache/expiration_cache_test.go +++ b/cache/expirationcache/expiration_cache_test.go @@ -11,11 +11,11 @@ var _ = Describe("Expiration cache", func() { Describe("Basic operations", func() { When("string cache was created", func() { It("Initial cache should be empty", func() { - cache := NewCache[string]() + cache := NewCache[string](Options{}) Expect(cache.TotalCount()).Should(Equal(0)) }) It("Initial cache should not contain any elements", func() { - cache := NewCache[string]() + cache := NewCache[string](Options{}) val, expiration := cache.Get("key1") Expect(val).Should(BeNil()) Expect(expiration).Should(Equal(time.Duration(0))) @@ -23,7 +23,7 @@ var _ = Describe("Expiration cache", func() { }) When("Put new value with positive TTL", func() { It("Should return the value before element expires", func() { - cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond)) + cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond}) v := "v1" cache.Put("key1", &v, 50*time.Millisecond) val, expiration := cache.Get("key1") @@ -33,7 +33,7 @@ var _ = Describe("Expiration cache", func() { Expect(cache.TotalCount()).Should(Equal(1)) }) It("Should return nil after expiration", func() { - cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond)) + cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond}) v := "v1" cache.Put("key1", &v, 50*time.Millisecond) @@ -52,7 +52,7 @@ var _ = Describe("Expiration cache", func() { }) When("Put new value without expiration", func() { It("Should not cache the value", func() { - cache := NewCache(WithCleanUpInterval[string](50 * time.Millisecond)) + cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond}) v := "x" cache.Put("key1", &v, 0) val, expiration := cache.Get("key1") @@ -63,7 +63,7 @@ var _ = Describe("Expiration cache", func() { }) When("Put updated value", func() { It("Should return updated value", func() { - cache := NewCache[string]() + cache := NewCache[string](Options{}) v1 := "v1" v2 := "v2" cache.Put("key1", &v1, 50*time.Millisecond) @@ -79,7 +79,7 @@ var _ = Describe("Expiration cache", func() { }) When("Purging after usage", func() { It("Should be empty after purge", func() { - cache := NewCache[string]() + cache := NewCache[string](Options{}) v1 := "y" cache.Put("key1", &v1, time.Second) @@ -91,15 +91,62 @@ var _ = Describe("Expiration cache", func() { }) }) }) + Describe("Hook functions", func() { + When("Hook functions are defined", func() { + It("should call each hook function", func() { + onCacheHitChannel := make(chan string, 10) + onCacheMissChannel := make(chan string, 10) + onAfterPutChannel := make(chan int, 10) + cache := NewCache[string](Options{ + OnCacheHitFn: func(key string) { + onCacheHitChannel <- key + }, + OnCacheMissFn: func(key string) { + onCacheMissChannel <- key + }, + OnAfterPutFn: func(newSize int) { + onAfterPutChannel <- newSize + }, + }) + + By("Get non existing value", func() { + val, _ := cache.Get("notExists") + Expect(val).Should(BeNil()) + + Expect(onCacheMissChannel).Should(Receive(Equal("notExists"))) + Expect(onCacheHitChannel).Should(Not(Receive())) + Expect(onAfterPutChannel).Should(Not(Receive())) + }) + + By("Put new cache entry", func() { + v1 := "v1" + cache.Put("key1", &v1, time.Second) + Expect(onCacheMissChannel).Should(Not(Receive())) + Expect(onCacheMissChannel).Should(Not(Receive())) + Expect(onAfterPutChannel).Should(Receive(Equal(1))) + }) + + By("Get existing value", func() { + val, _ := cache.Get("key1") + Expect(val).Should(HaveValue(Equal("v1"))) + + Expect(onCacheMissChannel).Should(Not(Receive())) + Expect(onCacheHitChannel).Should(Receive(Equal("key1"))) + Expect(onAfterPutChannel).Should(Not(Receive())) + }) + }) + }) + }) Describe("preExpiration function", func() { - When(" function is defined", func() { + When("function is defined", func() { It("should update the value and TTL if function returns values", func() { fn := func(key string) (val *string, ttl time.Duration) { v2 := "v2" return &v2, time.Second } - cache := NewCache(WithOnExpiredFn(fn)) + + cache := NewCacheWithOnExpired[string](Options{}, fn) v1 := "v1" cache.Put("key1", &v1, 50*time.Millisecond) @@ -118,7 +165,7 @@ var _ = Describe("Expiration cache", func() { return &v2, time.Second } - cache := NewCache(WithOnExpiredFn(fn)) + cache := NewCacheWithOnExpired[string](Options{}, fn) v1 := "somval" cache.Put("key1", &v1, time.Millisecond) @@ -139,7 +186,7 @@ var _ = Describe("Expiration cache", func() { fn := func(key string) (val *string, ttl time.Duration) { return nil, 0 } - cache := NewCache(WithCleanUpInterval[string](100*time.Millisecond), WithOnExpiredFn(fn)) + cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn) v1 := "z" cache.Put("key1", &v1, 50*time.Millisecond) @@ -152,7 +199,7 @@ var _ = Describe("Expiration cache", func() { Describe("LRU behaviour", func() { When("Defined max size is reached", func() { It("should remove old elements", func() { - cache := NewCache(WithMaxSize[string](3)) + cache := NewCache[string](Options{MaxSize: 3}) v1 := "val1" v2 := "val2" diff --git a/cache/expirationcache/prefetching_cache.go b/cache/expirationcache/prefetching_cache.go new file mode 100644 index 00000000..323c8457 --- /dev/null +++ b/cache/expirationcache/prefetching_cache.go @@ -0,0 +1,127 @@ +package expirationcache + +import ( + "sync/atomic" + "time" +) + +type PrefetchingExpiringLRUCache[T any] struct { + cache ExpiringCache[cacheValue[T]] + prefetchingNameCache ExpiringCache[atomic.Uint32] + reloadFn ReloadEntryFn[T] + prefetchThreshold int + prefetchExpires time.Duration + onPrefetchEntryReloaded OnEntryReloadedCallback + onPrefetchCacheHit OnCacheHitCallback +} + +type cacheValue[T any] struct { + element *T + prefetch bool +} + +// OnEntryReloadedCallback will be called if a prefetched entry is reloaded +type OnEntryReloadedCallback func(key string) + +// ReloadEntryFn reloads a prefetched entry by key +type ReloadEntryFn[T any] func(key string) (*T, time.Duration) + +type PrefetchingOptions[T any] struct { + Options + ReloadFn func(cacheKey string) (*T, time.Duration) + PrefetchThreshold int + PrefetchExpires time.Duration + PrefetchMaxItemsCount int + OnPrefetchAfterPut OnAfterPutCallback + OnPrefetchEntryReloaded OnEntryReloadedCallback + OnPrefetchCacheHit OnCacheHitCallback +} + +type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]]) + +func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] { + pc := &PrefetchingExpiringLRUCache[T]{ + prefetchingNameCache: NewCache[atomic.Uint32](Options{ + CleanupInterval: time.Minute, + MaxSize: uint(options.PrefetchMaxItemsCount), + OnAfterPutFn: options.OnPrefetchAfterPut, + }), + prefetchExpires: options.PrefetchExpires, + prefetchThreshold: options.PrefetchThreshold, + reloadFn: options.ReloadFn, + onPrefetchEntryReloaded: options.OnPrefetchEntryReloaded, + onPrefetchCacheHit: options.OnPrefetchCacheHit, + } + + pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired) + + return pc +} + +// check if a cache entry should be prefetched: was queried > threshold in the time window +func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool { + if e.prefetchThreshold == 0 { + return true + } + + cnt, _ := e.prefetchingNameCache.Get(cacheKey) + + return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold) +} + +func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) { + if e.shouldPrefetch(cacheKey) { + loadedVal, ttl := e.reloadFn(cacheKey) + if loadedVal != nil { + if e.onPrefetchEntryReloaded != nil { + e.onPrefetchEntryReloaded(cacheKey) + } + + return &cacheValue[T]{loadedVal, true}, ttl + } + } + + return nil, 0 +} + +func (e *PrefetchingExpiringLRUCache[T]) trackCacheKeyQueryCount(cacheKey string) { + var x *atomic.Uint32 + if x, _ = e.prefetchingNameCache.Get(cacheKey); x == nil { + x = &atomic.Uint32{} + } + + x.Add(1) + e.prefetchingNameCache.Put(cacheKey, x, e.prefetchExpires) +} + +func (e *PrefetchingExpiringLRUCache[T]) Put(key string, val *T, expiration time.Duration) { + e.cache.Put(key, &cacheValue[T]{element: val, prefetch: false}, expiration) +} + +// Get returns the value of cached entry with remained TTL. If entry is not cached, returns nil +func (e *PrefetchingExpiringLRUCache[T]) Get(key string) (val *T, expiration time.Duration) { + e.trackCacheKeyQueryCount(key) + + res, exp := e.cache.Get(key) + + if res == nil { + return nil, exp + } + + if e.onPrefetchCacheHit != nil && res.prefetch { + // Hit from prefetch cache + e.onPrefetchCacheHit(key) + } + + return res.element, exp +} + +// TotalCount returns the total count of valid (not expired) elements +func (e *PrefetchingExpiringLRUCache[T]) TotalCount() int { + return e.cache.TotalCount() +} + +// Clear removes all cache entries +func (e *PrefetchingExpiringLRUCache[T]) Clear() { + e.cache.Clear() +} diff --git a/cache/expirationcache/prefetching_cache_test.go b/cache/expirationcache/prefetching_cache_test.go new file mode 100644 index 00000000..7a1ef80b --- /dev/null +++ b/cache/expirationcache/prefetching_cache_test.go @@ -0,0 +1,183 @@ +package expirationcache + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Prefetching expiration cache", func() { + Describe("Basic operations", func() { + When("string cache was created", func() { + It("Initial cache should be empty", func() { + cache := NewPrefetchingCache[string](PrefetchingOptions[string]{}) + Expect(cache.TotalCount()).Should(Equal(0)) + }) + It("Initial cache should not contain any elements", func() { + cache := NewPrefetchingCache[string](PrefetchingOptions[string]{}) + val, expiration := cache.Get("key1") + Expect(val).Should(BeNil()) + Expect(expiration).Should(Equal(time.Duration(0))) + }) + + It("Should work as cache (basic operations)", func() { + cache := NewPrefetchingCache[string](PrefetchingOptions[string]{}) + v := "v1" + cache.Put("key1", &v, 50*time.Millisecond) + + val, expiration := cache.Get("key1") + Expect(val).Should(HaveValue(Equal("v1"))) + Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50)) + + Expect(cache.TotalCount()).Should(Equal(1)) + + cache.Clear() + + Expect(cache.TotalCount()).Should(Equal(0)) + }) + }) + Context("Prefetching", func() { + It("Should prefetch element", func() { + cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + Options: Options{ + CleanupInterval: 100 * time.Millisecond, + }, + PrefetchThreshold: 2, + PrefetchExpires: 100 * time.Millisecond, + ReloadFn: func(cacheKey string) (*string, time.Duration) { + v := "v2" + + return &v, 50 * time.Millisecond + }, + }) + By("put a value and get it again", func() { + v := "v1" + cache.Put("key1", &v, 50*time.Millisecond) + val, expiration := cache.Get("key1") + Expect(val).Should(HaveValue(Equal("v1"))) + Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50)) + + Expect(cache.TotalCount()).Should(Equal(1)) + }) + By("Get value twice to trigger prefetching", func() { + val, _ := cache.Get("key1") + Expect(val).Should(HaveValue(Equal("v1"))) + + Eventually(func(g Gomega) { + val, _ := cache.Get("key1") + g.Expect(val).Should(HaveValue(Equal("v2"))) + }).Should(Succeed()) + }) + }) + It("Should not prefetch element", func() { + cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + Options: Options{ + CleanupInterval: 100 * time.Millisecond, + }, + PrefetchThreshold: 2, + PrefetchExpires: 100 * time.Millisecond, + ReloadFn: func(cacheKey string) (*string, time.Duration) { + v := "v2" + + return &v, 50 * time.Millisecond + }, + }) + By("put a value and get it again", func() { + v := "v1" + cache.Put("key1", &v, 50*time.Millisecond) + val, expiration := cache.Get("key1") + Expect(val).Should(HaveValue(Equal("v1"))) + Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50)) + + Expect(cache.TotalCount()).Should(Equal(1)) + }) + By("Wait for expiration -> the entry should not be prefetched, threshold was not reached", func() { + Eventually(func(g Gomega) { + val, _ := cache.Get("key1") + g.Expect(val).Should(BeNil()) + }, "5s", "500ms").Should(Succeed()) + }) + }) + It("With default config (threshold = 0) should always prefetch", func() { + cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + Options: Options{ + CleanupInterval: 100 * time.Millisecond, + }, + ReloadFn: func(cacheKey string) (*string, time.Duration) { + v := "v2" + + return &v, 50 * time.Millisecond + }, + }) + By("put a value and get it again", func() { + v := "v1" + cache.Put("key1", &v, 50*time.Millisecond) + val, expiration := cache.Get("key1") + Expect(val).Should(HaveValue(Equal("v1"))) + Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50)) + }) + By("Should return new prefetched value after expiration", func() { + Eventually(func(g Gomega) { + val, _ := cache.Get("key1") + g.Expect(val).Should(HaveValue(Equal("v2"))) + }, "5s").Should(Succeed()) + }) + }) + It("Should execute hook functions", func() { + onPrefetchAfterPutChannel := make(chan int, 10) + onPrefetchEntryReloaded := make(chan string, 10) + onnPrefetchCacheHit := make(chan string, 10) + cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + Options: Options{ + CleanupInterval: 100 * time.Millisecond, + }, + PrefetchThreshold: 2, + PrefetchExpires: 100 * time.Millisecond, + ReloadFn: func(cacheKey string) (*string, time.Duration) { + v := "v2" + + return &v, 50 * time.Millisecond + }, + OnPrefetchAfterPut: func(newSize int) { onPrefetchAfterPutChannel <- newSize }, + OnPrefetchEntryReloaded: func(key string) { onPrefetchEntryReloaded <- key }, + OnPrefetchCacheHit: func(key string) { onnPrefetchCacheHit <- key }, + }) + By("put a value", func() { + v := "v1" + cache.Put("key1", &v, 50*time.Millisecond) + Expect(onPrefetchAfterPutChannel).Should(Not(Receive())) + Expect(onPrefetchEntryReloaded).Should(Not(Receive())) + Expect(onnPrefetchCacheHit).Should(Not(Receive())) + }) + By("get a value 3 times to trigger prefetching", func() { + // first get + cache.Get("key1") + + Expect(onPrefetchAfterPutChannel).Should(Receive(Equal(1))) + Expect(onnPrefetchCacheHit).Should(Not(Receive())) + Expect(onPrefetchEntryReloaded).Should(Not(Receive())) + + // secont get + val, _ := cache.Get("key1") + Expect(val).Should(HaveValue(Equal("v1"))) + + // third get -> this should trigger prefetching after expiration + cache.Get("key1") + + // reload was executed + Eventually(onPrefetchEntryReloaded).Should(Receive(Equal("key1"))) + Expect(onnPrefetchCacheHit).Should(Not(Receive())) + // has new value + Eventually(func(g Gomega) { + val, _ := cache.Get("key1") + g.Expect(val).Should(HaveValue(Equal("v2"))) + }, "5s").Should(Succeed()) + + // prefetch hit + Eventually(onnPrefetchCacheHit).Should(Receive(Equal("key1"))) + }) + }) + }) + }) +}) diff --git a/evt/events.go b/evt/events.go index 9fcc9136..b55a5cda 100644 --- a/evt/events.go +++ b/evt/events.go @@ -20,10 +20,10 @@ const ( // CachingPrefetchCacheHit fires if a query result was found in the prefetch cache, Parameter: domain name CachingPrefetchCacheHit = "caching:prefetchHit" - // CachingResultCacheHit fires, if a query result was found in the cache, Parameter: domain name + // CachingResultCacheHit fires, if a query result was found in the cache CachingResultCacheHit = "caching:cacheHit" - // CachingResultCacheMiss fires, if a query result was not found in the cache, Parameter: domain name + // CachingResultCacheMiss fires, if a query result was not found in the cache CachingResultCacheMiss = "caching:cacheMiss" // CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 1ec55e8c..c612df87 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -603,10 +603,11 @@ func (r *BlockingResolver) initFQDNIPCache() { identifiers = append(identifiers, identifier) } - r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]net.IP](defaultBlockingCleanUpInterval), - expirationcache.WithOnExpiredFn(func(key string) (val *[]net.IP, ttl time.Duration) { - return r.queryForFQIdentifierIPs(key) - })) + r.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{ + CleanupInterval: defaultBlockingCleanUpInterval, + }, func(key string) (val *[]net.IP, ttl time.Duration) { + return r.queryForFQIdentifierIPs(key) + }) for _, identifier := range identifiers { if isFQDN(identifier) { diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 4a733b8c..01d90071 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -29,15 +29,9 @@ type CachingResolver struct { emitMetricEvents bool // disabled by Bootstrap - resultCache expirationcache.ExpiringCache[cacheValue] - prefetchingNameCache expirationcache.ExpiringCache[int] - redisClient *redis.Client -} + resultCache expirationcache.ExpiringCache[dns.Msg] -// cacheValue includes query answer and prefetch flag -type cacheValue struct { - resultMsg *dns.Msg - prefetch bool + redisClient *redis.Client } // NewCachingResolver creates a new resolver instance @@ -65,73 +59,76 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri } func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { - cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval) - maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount)) + options := expirationcache.Options{ + CleanupInterval: defaultCachingCleanUpInterval, + MaxSize: uint(cfg.MaxItemsCount), + OnCacheHitFn: func(key string) { + c.publishMetricsIfEnabled(evt.CachingResultCacheHit, key) + }, + OnCacheMissFn: func(key string) { + c.publishMetricsIfEnabled(evt.CachingResultCacheMiss, key) + }, + OnAfterPutFn: func(newSize int) { + c.publishMetricsIfEnabled(evt.CachingResultCacheChanged, newSize) + }, + } if cfg.Prefetching { - c.prefetchingNameCache = expirationcache.NewCache( - expirationcache.WithCleanUpInterval[int](time.Minute), - expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)), - ) + prefetchingOptions := expirationcache.PrefetchingOptions[dns.Msg]{ + Options: options, + PrefetchExpires: time.Duration(cfg.PrefetchExpires), + PrefetchThreshold: cfg.PrefetchThreshold, + PrefetchMaxItemsCount: cfg.PrefetchMaxItemsCount, + ReloadFn: c.reloadCacheEntry, + OnPrefetchAfterPut: func(newSize int) { + c.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, newSize) + }, + OnPrefetchEntryReloaded: func(key string) { + c.publishMetricsIfEnabled(evt.CachingDomainPrefetched, key) + }, + OnPrefetchCacheHit: func(key string) { + c.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, key) + }, + } - c.resultCache = expirationcache.NewCache( - cleanupOption, - maxSizeOption, - expirationcache.WithOnExpiredFn(c.onExpired), - ) + c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions) } else { - c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption) + c.resultCache = expirationcache.NewCache[dns.Msg](options) } } +func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Duration) { + qType, domainName := util.ExtractCacheKey(cacheKey) + logger := r.log() + + logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) + + req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger) + response, err := r.next.Resolve(req) + + if err == nil { + if response.Res.Rcode == dns.RcodeSuccess { + return response.Res, r.adjustTTLs(response.Res.Answer) + } + } else { + util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err) + } + + return nil, 0 +} + func setupRedisCacheSubscriber(c *CachingResolver) { go func() { for rc := range c.redisClient.CacheChannel { if rc != nil { c.log().Debug("Received key from redis: ", rc.Key) ttl := c.adjustTTLs(rc.Response.Res.Answer) - c.putInCache(rc.Key, rc.Response, ttl, false, false) + c.putInCache(rc.Key, rc.Response, ttl, false) } } }() } -// check if domain was queried > threshold in the time window -func (r *CachingResolver) shouldPrefetch(cacheKey string) bool { - if r.cfg.PrefetchThreshold == 0 { - return true - } - - cnt, _ := r.prefetchingNameCache.Get(cacheKey) - - return cnt != nil && *cnt > r.cfg.PrefetchThreshold -} - -func (r *CachingResolver) onExpired(cacheKey string) (val *cacheValue, ttl time.Duration) { - qType, domainName := util.ExtractCacheKey(cacheKey) - - if r.shouldPrefetch(cacheKey) { - logger := r.log() - - logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) - - req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger) - response, err := r.next.Resolve(req) - - if err == nil { - if response.Res.Rcode == dns.RcodeSuccess { - r.publishMetricsIfEnabled(evt.CachingDomainPrefetched, domainName) - - return &cacheValue{response.Res, true}, r.adjustTTLs(response.Res.Answer) - } - } else { - util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err) - } - } - - return nil, 0 -} - // LogConfig implements `config.Configurable`. func (r *CachingResolver) LogConfig(logger *logrus.Entry) { r.cfg.LogConfig(logger) @@ -155,23 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain) logger := logger.WithField("domain", util.Obfuscate(domain)) - r.trackQueryDomainNameCount(domain, cacheKey, logger) - val, ttl := r.resultCache.Get(cacheKey) if val != nil { logger.Debug("domain is cached") - r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain) - - if val.prefetch { - // Hit from prefetch cache - r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain) - } - - resp := val.resultMsg.Copy() + resp := val.Copy() resp.SetReply(request.Req) - resp.Rcode = val.resultMsg.Rcode + resp.Rcode = val.Rcode // Adjust TTL setTTLInCachedResponse(resp, ttl) @@ -183,14 +171,12 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil } - r.publishMetricsIfEnabled(evt.CachingResultCacheMiss, domain) - logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver") response, err = r.next.Resolve(request) if err == nil { cacheTTL := r.adjustTTLs(response.Res.Answer) - r.putInCache(cacheKey, response, cacheTTL, false, true) + r.putInCache(cacheKey, response, cacheTTL, true) } } @@ -209,22 +195,6 @@ func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) { } } -func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, logger *logrus.Entry) { - if r.prefetchingNameCache != nil { - var domainCount int - if x, _ := r.prefetchingNameCache.Get(cacheKey); x != nil { - domainCount = *x - } - domainCount++ - r.prefetchingNameCache.Put(cacheKey, &domainCount, r.cfg.PrefetchExpires.ToDuration()) - totalCount := r.prefetchingNameCache.TotalCount() - - logger.Debugf("domain '%s' was requested %d times, "+ - "total cache size: %d", util.Obfuscate(domain), domainCount, totalCount) - r.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, totalCount) - } -} - // removes EDNS OPT records from message func removeEdns0Extra(msg *dns.Msg) { if len(msg.Extra) > 0 { @@ -246,7 +216,7 @@ func shouldBeCached(msg *dns.Msg) bool { } func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration, - prefetch, publish bool, + publish bool, ) { respCopy := response.Res.Copy() @@ -255,16 +225,14 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) { // put value into cache - r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, ttl) + r.resultCache.Put(cacheKey, respCopy, ttl) } else if response.Res.Rcode == dns.RcodeNameError { if r.cfg.CacheTimeNegative.IsAboveZero() { // put negative cache if result code is NXDOMAIN - r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, r.cfg.CacheTimeNegative.ToDuration()) + r.resultCache.Put(cacheKey, respCopy, r.cfg.CacheTimeNegative.ToDuration()) } } - r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount()) - if publish && r.redisClient != nil { res := *respCopy r.redisClient.PublishCache(cacheKey, &res) diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index 2da8c3d7..63a7333c 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/0xERR0R/blocky/cache/expirationcache" "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/evt" . "github.com/0xERR0R/blocky/helpertest" @@ -81,18 +80,15 @@ var _ = Describe("CachingResolver", func() { // prepare resolver, set smaller caching times for testing prefetchThreshold := 5 configureCaches(sut, &sutConfig) - sut.resultCache = expirationcache.NewCache( - expirationcache.WithCleanUpInterval[cacheValue](100*time.Millisecond), - expirationcache.WithOnExpiredFn(sut.onExpired)) - domainPrefetched := make(chan string, 1) - prefetchHitDomain := make(chan string, 1) + domainPrefetched := make(chan bool, 1) + prefetchHitDomain := make(chan bool, 1) prefetchedCnt := make(chan int, 1) Expect(Bus().SubscribeOnce(CachingPrefetchCacheHit, func(domain string) { - prefetchHitDomain <- domain + prefetchHitDomain <- true })).Should(Succeed()) Expect(Bus().SubscribeOnce(CachingDomainPrefetched, func(domain string) { - domainPrefetched <- domain + domainPrefetched <- true })).Should(Succeed()) Expect(Bus().SubscribeOnce(CachingDomainsToPrefetchCountChanged, func(cnt int) { @@ -115,7 +111,7 @@ var _ = Describe("CachingResolver", func() { } // now is this domain prefetched - Eventually(domainPrefetched, "4s").Should(Receive(Equal("example.com"))) + Eventually(domainPrefetched, "10s").Should(Receive(Equal(true))) // and it should hit from prefetch cache Expect(sut.Resolve(newRequest("example.com.", A))). @@ -125,16 +121,7 @@ var _ = Describe("CachingResolver", func() { HaveReturnCode(dns.RcodeSuccess), BeDNSRecord("example.com.", A, "123.122.121.120"), HaveTTL(BeNumerically("<=", 2)))) - Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com"))) - }) - When("threshold is 0", func() { - BeforeEach(func() { - sutConfig.PrefetchThreshold = 0 - }) - - It("should always prefetch", func() { - Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue()) - }) + Eventually(prefetchHitDomain, "10s").Should(Receive(Equal(true))) }) }) When("caching with default values is enabled", func() { @@ -204,9 +191,9 @@ var _ = Describe("CachingResolver", func() { It("should cache response and use response's TTL", func() { By("first request", func() { - domain := make(chan string, 1) + domain := make(chan bool, 1) _ = Bus().SubscribeOnce(CachingResultCacheMiss, func(d string) { - domain <- d + domain <- true }) totalCacheCount := make(chan int, 1) @@ -223,15 +210,15 @@ var _ = Describe("CachingResolver", func() { Expect(m.Calls).Should(HaveLen(1)) - Expect(domain).Should(Receive(Equal("example.com"))) + Expect(domain).Should(Receive(Equal(true))) Expect(totalCacheCount).Should(Receive(Equal(1))) }) By("second request", func() { Eventually(func(g Gomega) { - domain := make(chan string, 1) + domain := make(chan bool, 1) _ = Bus().SubscribeOnce(CachingResultCacheHit, func(d string) { - domain <- d + domain <- true }) g.Expect(sut.Resolve(newRequest("example.com.", A))). @@ -246,7 +233,7 @@ var _ = Describe("CachingResolver", func() { // still one call to upstream g.Expect(m.Calls).Should(HaveLen(1)) - g.Expect(domain).Should(Receive(Equal("example.com"))) + g.Expect(domain).Should(Receive(Equal(true))) }, "1s").Should(Succeed()) }) }) diff --git a/resolver/client_names_resolver.go b/resolver/client_names_resolver.go index 37d30c8c..4dee1066 100644 --- a/resolver/client_names_resolver.go +++ b/resolver/client_names_resolver.go @@ -41,7 +41,9 @@ func NewClientNamesResolver( configurable: withConfig(&cfg), typed: withType("client_names"), - cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]string](time.Hour)), + cache: expirationcache.NewCache[[]string](expirationcache.Options{ + CleanupInterval: time.Hour, + }), externalResolver: r, }