chore(refactor): refactor cache implementation (#1174)

* chore(refactor): refactor cache implementation

* chore: use atomic.Uint32 as prefetch names query count

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

---------

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
This commit is contained in:
Dimitri Herzog 2023-09-30 17:14:59 +02:00 committed by GitHub
parent d76740ea51
commit 497bd0d0fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 508 additions and 159 deletions

View File

@ -19,15 +19,18 @@ type element[T any] struct {
type ExpiringLRUCache[T any] struct {
cleanUpInterval time.Duration
preExpirationFn OnExpirationCallback[T]
onCacheHit OnCacheHitCallback
onCacheMiss OnCacheMissCallback
onAfterPut OnAfterPutCallback
lru *lru.Cache
}
type CacheOption[T any] func(c *ExpiringLRUCache[T])
func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
return func(e *ExpiringLRUCache[T]) {
e.cleanUpInterval = d
}
type Options struct {
OnCacheHitFn OnCacheHitCallback
OnCacheMissFn OnCacheMissCallback
OnAfterPutFn OnAfterPutCallback
CleanupInterval time.Duration
MaxSize uint
}
// OnExpirationCallback will be called just before an element gets expired and will
@ -35,33 +38,56 @@ func WithCleanUpInterval[T any](d time.Duration) CacheOption[T] {
// element in the cache or nil to remove it
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
func WithOnExpiredFn[T any](fn OnExpirationCallback[T]) CacheOption[T] {
return func(c *ExpiringLRUCache[T]) {
c.preExpirationFn = fn
}
// OnCacheHitCallback will be called on cache get if entry was found
type OnCacheHitCallback func(key string)
// OnCacheMissCallback will be called on cache get and entry was not found
type OnCacheMissCallback func(key string)
// OnAfterPutCallback will be called after put, receives new element count as parameter
type OnAfterPutCallback func(newSize int)
func NewCache[T any](options Options) *ExpiringLRUCache[T] {
return NewCacheWithOnExpired[T](options, nil)
}
func WithMaxSize[T any](size uint) CacheOption[T] {
return func(c *ExpiringLRUCache[T]) {
if size > 0 {
l, _ := lru.New(int(size))
c.lru = l
}
}
}
func NewCache[T any](options ...CacheOption[T]) *ExpiringLRUCache[T] {
func NewCacheWithOnExpired[T any](options Options,
onExpirationFn OnExpirationCallback[T],
) *ExpiringLRUCache[T] {
l, _ := lru.New(defaultSize)
c := &ExpiringLRUCache[T]{
cleanUpInterval: defaultCleanUpInterval,
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
return nil, 0
},
lru: l,
onCacheHit: func(key string) {},
onCacheMiss: func(key string) {},
lru: l,
}
for _, opt := range options {
opt(c)
if options.CleanupInterval > 0 {
c.cleanUpInterval = options.CleanupInterval
}
if options.MaxSize > 0 {
l, _ := lru.New(int(options.MaxSize))
c.lru = l
}
if options.OnAfterPutFn != nil {
c.onAfterPut = options.OnAfterPutFn
}
if options.OnCacheHitFn != nil {
c.onCacheHit = options.OnCacheHitFn
}
if options.OnCacheMissFn != nil {
c.onCacheMiss = options.OnCacheMissFn
}
if onExpirationFn != nil {
c.preExpirationFn = onExpirationFn
}
go periodicCleanup(c)
@ -122,15 +148,23 @@ func (e *ExpiringLRUCache[T]) Put(key string, val *T, ttl time.Duration) {
val: val,
expiresEpochMs: expiresEpochMs,
})
if e.onAfterPut != nil {
e.onAfterPut(e.lru.Len())
}
}
func (e *ExpiringLRUCache[T]) Get(key string) (val *T, ttl time.Duration) {
el, found := e.lru.Get(key)
if found {
e.onCacheHit(key)
return el.(*element[T]).val, calculateRemainTTL(el.(*element[T]).expiresEpochMs)
}
e.onCacheMiss(key)
return nil, 0
}

View File

@ -11,11 +11,11 @@ var _ = Describe("Expiration cache", func() {
Describe("Basic operations", func() {
When("string cache was created", func() {
It("Initial cache should be empty", func() {
cache := NewCache[string]()
cache := NewCache[string](Options{})
Expect(cache.TotalCount()).Should(Equal(0))
})
It("Initial cache should not contain any elements", func() {
cache := NewCache[string]()
cache := NewCache[string](Options{})
val, expiration := cache.Get("key1")
Expect(val).Should(BeNil())
Expect(expiration).Should(Equal(time.Duration(0)))
@ -23,7 +23,7 @@ var _ = Describe("Expiration cache", func() {
})
When("Put new value with positive TTL", func() {
It("Should return the value before element expires", func() {
cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
@ -33,7 +33,7 @@ var _ = Describe("Expiration cache", func() {
Expect(cache.TotalCount()).Should(Equal(1))
})
It("Should return nil after expiration", func() {
cache := NewCache(WithCleanUpInterval[string](100 * time.Millisecond))
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
@ -52,7 +52,7 @@ var _ = Describe("Expiration cache", func() {
})
When("Put new value without expiration", func() {
It("Should not cache the value", func() {
cache := NewCache(WithCleanUpInterval[string](50 * time.Millisecond))
cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond})
v := "x"
cache.Put("key1", &v, 0)
val, expiration := cache.Get("key1")
@ -63,7 +63,7 @@ var _ = Describe("Expiration cache", func() {
})
When("Put updated value", func() {
It("Should return updated value", func() {
cache := NewCache[string]()
cache := NewCache[string](Options{})
v1 := "v1"
v2 := "v2"
cache.Put("key1", &v1, 50*time.Millisecond)
@ -79,7 +79,7 @@ var _ = Describe("Expiration cache", func() {
})
When("Purging after usage", func() {
It("Should be empty after purge", func() {
cache := NewCache[string]()
cache := NewCache[string](Options{})
v1 := "y"
cache.Put("key1", &v1, time.Second)
@ -91,15 +91,62 @@ var _ = Describe("Expiration cache", func() {
})
})
})
Describe("Hook functions", func() {
When("Hook functions are defined", func() {
It("should call each hook function", func() {
onCacheHitChannel := make(chan string, 10)
onCacheMissChannel := make(chan string, 10)
onAfterPutChannel := make(chan int, 10)
cache := NewCache[string](Options{
OnCacheHitFn: func(key string) {
onCacheHitChannel <- key
},
OnCacheMissFn: func(key string) {
onCacheMissChannel <- key
},
OnAfterPutFn: func(newSize int) {
onAfterPutChannel <- newSize
},
})
By("Get non existing value", func() {
val, _ := cache.Get("notExists")
Expect(val).Should(BeNil())
Expect(onCacheMissChannel).Should(Receive(Equal("notExists")))
Expect(onCacheHitChannel).Should(Not(Receive()))
Expect(onAfterPutChannel).Should(Not(Receive()))
})
By("Put new cache entry", func() {
v1 := "v1"
cache.Put("key1", &v1, time.Second)
Expect(onCacheMissChannel).Should(Not(Receive()))
Expect(onCacheMissChannel).Should(Not(Receive()))
Expect(onAfterPutChannel).Should(Receive(Equal(1)))
})
By("Get existing value", func() {
val, _ := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(onCacheMissChannel).Should(Not(Receive()))
Expect(onCacheHitChannel).Should(Receive(Equal("key1")))
Expect(onAfterPutChannel).Should(Not(Receive()))
})
})
})
})
Describe("preExpiration function", func() {
When(" function is defined", func() {
When("function is defined", func() {
It("should update the value and TTL if function returns values", func() {
fn := func(key string) (val *string, ttl time.Duration) {
v2 := "v2"
return &v2, time.Second
}
cache := NewCache(WithOnExpiredFn(fn))
cache := NewCacheWithOnExpired[string](Options{}, fn)
v1 := "v1"
cache.Put("key1", &v1, 50*time.Millisecond)
@ -118,7 +165,7 @@ var _ = Describe("Expiration cache", func() {
return &v2, time.Second
}
cache := NewCache(WithOnExpiredFn(fn))
cache := NewCacheWithOnExpired[string](Options{}, fn)
v1 := "somval"
cache.Put("key1", &v1, time.Millisecond)
@ -139,7 +186,7 @@ var _ = Describe("Expiration cache", func() {
fn := func(key string) (val *string, ttl time.Duration) {
return nil, 0
}
cache := NewCache(WithCleanUpInterval[string](100*time.Millisecond), WithOnExpiredFn(fn))
cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn)
v1 := "z"
cache.Put("key1", &v1, 50*time.Millisecond)
@ -152,7 +199,7 @@ var _ = Describe("Expiration cache", func() {
Describe("LRU behaviour", func() {
When("Defined max size is reached", func() {
It("should remove old elements", func() {
cache := NewCache(WithMaxSize[string](3))
cache := NewCache[string](Options{MaxSize: 3})
v1 := "val1"
v2 := "val2"

View File

@ -0,0 +1,127 @@
package expirationcache
import (
"sync/atomic"
"time"
)
type PrefetchingExpiringLRUCache[T any] struct {
cache ExpiringCache[cacheValue[T]]
prefetchingNameCache ExpiringCache[atomic.Uint32]
reloadFn ReloadEntryFn[T]
prefetchThreshold int
prefetchExpires time.Duration
onPrefetchEntryReloaded OnEntryReloadedCallback
onPrefetchCacheHit OnCacheHitCallback
}
type cacheValue[T any] struct {
element *T
prefetch bool
}
// OnEntryReloadedCallback will be called if a prefetched entry is reloaded
type OnEntryReloadedCallback func(key string)
// ReloadEntryFn reloads a prefetched entry by key
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
type PrefetchingOptions[T any] struct {
Options
ReloadFn func(cacheKey string) (*T, time.Duration)
PrefetchThreshold int
PrefetchExpires time.Duration
PrefetchMaxItemsCount int
OnPrefetchAfterPut OnAfterPutCallback
OnPrefetchEntryReloaded OnEntryReloadedCallback
OnPrefetchCacheHit OnCacheHitCallback
}
type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]])
func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
pc := &PrefetchingExpiringLRUCache[T]{
prefetchingNameCache: NewCache[atomic.Uint32](Options{
CleanupInterval: time.Minute,
MaxSize: uint(options.PrefetchMaxItemsCount),
OnAfterPutFn: options.OnPrefetchAfterPut,
}),
prefetchExpires: options.PrefetchExpires,
prefetchThreshold: options.PrefetchThreshold,
reloadFn: options.ReloadFn,
onPrefetchEntryReloaded: options.OnPrefetchEntryReloaded,
onPrefetchCacheHit: options.OnPrefetchCacheHit,
}
pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired)
return pc
}
// check if a cache entry should be prefetched: was queried > threshold in the time window
func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
if e.prefetchThreshold == 0 {
return true
}
cnt, _ := e.prefetchingNameCache.Get(cacheKey)
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
}
func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
if e.shouldPrefetch(cacheKey) {
loadedVal, ttl := e.reloadFn(cacheKey)
if loadedVal != nil {
if e.onPrefetchEntryReloaded != nil {
e.onPrefetchEntryReloaded(cacheKey)
}
return &cacheValue[T]{loadedVal, true}, ttl
}
}
return nil, 0
}
func (e *PrefetchingExpiringLRUCache[T]) trackCacheKeyQueryCount(cacheKey string) {
var x *atomic.Uint32
if x, _ = e.prefetchingNameCache.Get(cacheKey); x == nil {
x = &atomic.Uint32{}
}
x.Add(1)
e.prefetchingNameCache.Put(cacheKey, x, e.prefetchExpires)
}
func (e *PrefetchingExpiringLRUCache[T]) Put(key string, val *T, expiration time.Duration) {
e.cache.Put(key, &cacheValue[T]{element: val, prefetch: false}, expiration)
}
// Get returns the value of cached entry with remained TTL. If entry is not cached, returns nil
func (e *PrefetchingExpiringLRUCache[T]) Get(key string) (val *T, expiration time.Duration) {
e.trackCacheKeyQueryCount(key)
res, exp := e.cache.Get(key)
if res == nil {
return nil, exp
}
if e.onPrefetchCacheHit != nil && res.prefetch {
// Hit from prefetch cache
e.onPrefetchCacheHit(key)
}
return res.element, exp
}
// TotalCount returns the total count of valid (not expired) elements
func (e *PrefetchingExpiringLRUCache[T]) TotalCount() int {
return e.cache.TotalCount()
}
// Clear removes all cache entries
func (e *PrefetchingExpiringLRUCache[T]) Clear() {
e.cache.Clear()
}

View File

@ -0,0 +1,183 @@
package expirationcache
import (
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Prefetching expiration cache", func() {
Describe("Basic operations", func() {
When("string cache was created", func() {
It("Initial cache should be empty", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
Expect(cache.TotalCount()).Should(Equal(0))
})
It("Initial cache should not contain any elements", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
val, expiration := cache.Get("key1")
Expect(val).Should(BeNil())
Expect(expiration).Should(Equal(time.Duration(0)))
})
It("Should work as cache (basic operations)", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
Expect(cache.TotalCount()).Should(Equal(1))
cache.Clear()
Expect(cache.TotalCount()).Should(Equal(0))
})
})
Context("Prefetching", func() {
It("Should prefetch element", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
})
By("put a value and get it again", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
Expect(cache.TotalCount()).Should(Equal(1))
})
By("Get value twice to trigger prefetching", func() {
val, _ := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(HaveValue(Equal("v2")))
}).Should(Succeed())
})
})
It("Should not prefetch element", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
})
By("put a value and get it again", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
Expect(cache.TotalCount()).Should(Equal(1))
})
By("Wait for expiration -> the entry should not be prefetched, threshold was not reached", func() {
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(BeNil())
}, "5s", "500ms").Should(Succeed())
})
})
It("With default config (threshold = 0) should always prefetch", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
})
By("put a value and get it again", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50))
})
By("Should return new prefetched value after expiration", func() {
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(HaveValue(Equal("v2")))
}, "5s").Should(Succeed())
})
})
It("Should execute hook functions", func() {
onPrefetchAfterPutChannel := make(chan int, 10)
onPrefetchEntryReloaded := make(chan string, 10)
onnPrefetchCacheHit := make(chan string, 10)
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
},
OnPrefetchAfterPut: func(newSize int) { onPrefetchAfterPutChannel <- newSize },
OnPrefetchEntryReloaded: func(key string) { onPrefetchEntryReloaded <- key },
OnPrefetchCacheHit: func(key string) { onnPrefetchCacheHit <- key },
})
By("put a value", func() {
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
Expect(onPrefetchAfterPutChannel).Should(Not(Receive()))
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
})
By("get a value 3 times to trigger prefetching", func() {
// first get
cache.Get("key1")
Expect(onPrefetchAfterPutChannel).Should(Receive(Equal(1)))
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
Expect(onPrefetchEntryReloaded).Should(Not(Receive()))
// secont get
val, _ := cache.Get("key1")
Expect(val).Should(HaveValue(Equal("v1")))
// third get -> this should trigger prefetching after expiration
cache.Get("key1")
// reload was executed
Eventually(onPrefetchEntryReloaded).Should(Receive(Equal("key1")))
Expect(onnPrefetchCacheHit).Should(Not(Receive()))
// has new value
Eventually(func(g Gomega) {
val, _ := cache.Get("key1")
g.Expect(val).Should(HaveValue(Equal("v2")))
}, "5s").Should(Succeed())
// prefetch hit
Eventually(onnPrefetchCacheHit).Should(Receive(Equal("key1")))
})
})
})
})
})

View File

@ -20,10 +20,10 @@ const (
// CachingPrefetchCacheHit fires if a query result was found in the prefetch cache, Parameter: domain name
CachingPrefetchCacheHit = "caching:prefetchHit"
// CachingResultCacheHit fires, if a query result was found in the cache, Parameter: domain name
// CachingResultCacheHit fires, if a query result was found in the cache
CachingResultCacheHit = "caching:cacheHit"
// CachingResultCacheMiss fires, if a query result was not found in the cache, Parameter: domain name
// CachingResultCacheMiss fires, if a query result was not found in the cache
CachingResultCacheMiss = "caching:cacheMiss"
// CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count

View File

@ -603,10 +603,11 @@ func (r *BlockingResolver) initFQDNIPCache() {
identifiers = append(identifiers, identifier)
}
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]net.IP](defaultBlockingCleanUpInterval),
expirationcache.WithOnExpiredFn(func(key string) (val *[]net.IP, ttl time.Duration) {
return r.queryForFQIdentifierIPs(key)
}))
r.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{
CleanupInterval: defaultBlockingCleanUpInterval,
}, func(key string) (val *[]net.IP, ttl time.Duration) {
return r.queryForFQIdentifierIPs(key)
})
for _, identifier := range identifiers {
if isFQDN(identifier) {

View File

@ -29,15 +29,9 @@ type CachingResolver struct {
emitMetricEvents bool // disabled by Bootstrap
resultCache expirationcache.ExpiringCache[cacheValue]
prefetchingNameCache expirationcache.ExpiringCache[int]
redisClient *redis.Client
}
resultCache expirationcache.ExpiringCache[dns.Msg]
// cacheValue includes query answer and prefetch flag
type cacheValue struct {
resultMsg *dns.Msg
prefetch bool
redisClient *redis.Client
}
// NewCachingResolver creates a new resolver instance
@ -65,73 +59,76 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
}
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval)
maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount))
options := expirationcache.Options{
CleanupInterval: defaultCachingCleanUpInterval,
MaxSize: uint(cfg.MaxItemsCount),
OnCacheHitFn: func(key string) {
c.publishMetricsIfEnabled(evt.CachingResultCacheHit, key)
},
OnCacheMissFn: func(key string) {
c.publishMetricsIfEnabled(evt.CachingResultCacheMiss, key)
},
OnAfterPutFn: func(newSize int) {
c.publishMetricsIfEnabled(evt.CachingResultCacheChanged, newSize)
},
}
if cfg.Prefetching {
c.prefetchingNameCache = expirationcache.NewCache(
expirationcache.WithCleanUpInterval[int](time.Minute),
expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)),
)
prefetchingOptions := expirationcache.PrefetchingOptions[dns.Msg]{
Options: options,
PrefetchExpires: time.Duration(cfg.PrefetchExpires),
PrefetchThreshold: cfg.PrefetchThreshold,
PrefetchMaxItemsCount: cfg.PrefetchMaxItemsCount,
ReloadFn: c.reloadCacheEntry,
OnPrefetchAfterPut: func(newSize int) {
c.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, newSize)
},
OnPrefetchEntryReloaded: func(key string) {
c.publishMetricsIfEnabled(evt.CachingDomainPrefetched, key)
},
OnPrefetchCacheHit: func(key string) {
c.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, key)
},
}
c.resultCache = expirationcache.NewCache(
cleanupOption,
maxSizeOption,
expirationcache.WithOnExpiredFn(c.onExpired),
)
c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions)
} else {
c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption)
c.resultCache = expirationcache.NewCache[dns.Msg](options)
}
}
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
logger := r.log()
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
response, err := r.next.Resolve(req)
if err == nil {
if response.Res.Rcode == dns.RcodeSuccess {
return response.Res, r.adjustTTLs(response.Res.Answer)
}
} else {
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
}
return nil, 0
}
func setupRedisCacheSubscriber(c *CachingResolver) {
go func() {
for rc := range c.redisClient.CacheChannel {
if rc != nil {
c.log().Debug("Received key from redis: ", rc.Key)
ttl := c.adjustTTLs(rc.Response.Res.Answer)
c.putInCache(rc.Key, rc.Response, ttl, false, false)
c.putInCache(rc.Key, rc.Response, ttl, false)
}
}
}()
}
// check if domain was queried > threshold in the time window
func (r *CachingResolver) shouldPrefetch(cacheKey string) bool {
if r.cfg.PrefetchThreshold == 0 {
return true
}
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
return cnt != nil && *cnt > r.cfg.PrefetchThreshold
}
func (r *CachingResolver) onExpired(cacheKey string) (val *cacheValue, ttl time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
if r.shouldPrefetch(cacheKey) {
logger := r.log()
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
response, err := r.next.Resolve(req)
if err == nil {
if response.Res.Rcode == dns.RcodeSuccess {
r.publishMetricsIfEnabled(evt.CachingDomainPrefetched, domainName)
return &cacheValue{response.Res, true}, r.adjustTTLs(response.Res.Answer)
}
} else {
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
}
}
return nil, 0
}
// LogConfig implements `config.Configurable`.
func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
r.cfg.LogConfig(logger)
@ -155,23 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
logger := logger.WithField("domain", util.Obfuscate(domain))
r.trackQueryDomainNameCount(domain, cacheKey, logger)
val, ttl := r.resultCache.Get(cacheKey)
if val != nil {
logger.Debug("domain is cached")
r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain)
if val.prefetch {
// Hit from prefetch cache
r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
}
resp := val.resultMsg.Copy()
resp := val.Copy()
resp.SetReply(request.Req)
resp.Rcode = val.resultMsg.Rcode
resp.Rcode = val.Rcode
// Adjust TTL
setTTLInCachedResponse(resp, ttl)
@ -183,14 +171,12 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
}
r.publishMetricsIfEnabled(evt.CachingResultCacheMiss, domain)
logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
response, err = r.next.Resolve(request)
if err == nil {
cacheTTL := r.adjustTTLs(response.Res.Answer)
r.putInCache(cacheKey, response, cacheTTL, false, true)
r.putInCache(cacheKey, response, cacheTTL, true)
}
}
@ -209,22 +195,6 @@ func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
}
}
func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, logger *logrus.Entry) {
if r.prefetchingNameCache != nil {
var domainCount int
if x, _ := r.prefetchingNameCache.Get(cacheKey); x != nil {
domainCount = *x
}
domainCount++
r.prefetchingNameCache.Put(cacheKey, &domainCount, r.cfg.PrefetchExpires.ToDuration())
totalCount := r.prefetchingNameCache.TotalCount()
logger.Debugf("domain '%s' was requested %d times, "+
"total cache size: %d", util.Obfuscate(domain), domainCount, totalCount)
r.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, totalCount)
}
}
// removes EDNS OPT records from message
func removeEdns0Extra(msg *dns.Msg) {
if len(msg.Extra) > 0 {
@ -246,7 +216,7 @@ func shouldBeCached(msg *dns.Msg) bool {
}
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
prefetch, publish bool,
publish bool,
) {
respCopy := response.Res.Copy()
@ -255,16 +225,14 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) {
// put value into cache
r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, ttl)
r.resultCache.Put(cacheKey, respCopy, ttl)
} else if response.Res.Rcode == dns.RcodeNameError {
if r.cfg.CacheTimeNegative.IsAboveZero() {
// put negative cache if result code is NXDOMAIN
r.resultCache.Put(cacheKey, &cacheValue{respCopy, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
r.resultCache.Put(cacheKey, respCopy, r.cfg.CacheTimeNegative.ToDuration())
}
}
r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount())
if publish && r.redisClient != nil {
res := *respCopy
r.redisClient.PublishCache(cacheKey, &res)

View File

@ -4,7 +4,6 @@ import (
"fmt"
"time"
"github.com/0xERR0R/blocky/cache/expirationcache"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/evt"
. "github.com/0xERR0R/blocky/helpertest"
@ -81,18 +80,15 @@ var _ = Describe("CachingResolver", func() {
// prepare resolver, set smaller caching times for testing
prefetchThreshold := 5
configureCaches(sut, &sutConfig)
sut.resultCache = expirationcache.NewCache(
expirationcache.WithCleanUpInterval[cacheValue](100*time.Millisecond),
expirationcache.WithOnExpiredFn(sut.onExpired))
domainPrefetched := make(chan string, 1)
prefetchHitDomain := make(chan string, 1)
domainPrefetched := make(chan bool, 1)
prefetchHitDomain := make(chan bool, 1)
prefetchedCnt := make(chan int, 1)
Expect(Bus().SubscribeOnce(CachingPrefetchCacheHit, func(domain string) {
prefetchHitDomain <- domain
prefetchHitDomain <- true
})).Should(Succeed())
Expect(Bus().SubscribeOnce(CachingDomainPrefetched, func(domain string) {
domainPrefetched <- domain
domainPrefetched <- true
})).Should(Succeed())
Expect(Bus().SubscribeOnce(CachingDomainsToPrefetchCountChanged, func(cnt int) {
@ -115,7 +111,7 @@ var _ = Describe("CachingResolver", func() {
}
// now is this domain prefetched
Eventually(domainPrefetched, "4s").Should(Receive(Equal("example.com")))
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
// and it should hit from prefetch cache
Expect(sut.Resolve(newRequest("example.com.", A))).
@ -125,16 +121,7 @@ var _ = Describe("CachingResolver", func() {
HaveReturnCode(dns.RcodeSuccess),
BeDNSRecord("example.com.", A, "123.122.121.120"),
HaveTTL(BeNumerically("<=", 2))))
Eventually(prefetchHitDomain, "4s").Should(Receive(Equal("example.com")))
})
When("threshold is 0", func() {
BeforeEach(func() {
sutConfig.PrefetchThreshold = 0
})
It("should always prefetch", func() {
Expect(sut.shouldPrefetch("domain.tld")).Should(BeTrue())
})
Eventually(prefetchHitDomain, "10s").Should(Receive(Equal(true)))
})
})
When("caching with default values is enabled", func() {
@ -204,9 +191,9 @@ var _ = Describe("CachingResolver", func() {
It("should cache response and use response's TTL", func() {
By("first request", func() {
domain := make(chan string, 1)
domain := make(chan bool, 1)
_ = Bus().SubscribeOnce(CachingResultCacheMiss, func(d string) {
domain <- d
domain <- true
})
totalCacheCount := make(chan int, 1)
@ -223,15 +210,15 @@ var _ = Describe("CachingResolver", func() {
Expect(m.Calls).Should(HaveLen(1))
Expect(domain).Should(Receive(Equal("example.com")))
Expect(domain).Should(Receive(Equal(true)))
Expect(totalCacheCount).Should(Receive(Equal(1)))
})
By("second request", func() {
Eventually(func(g Gomega) {
domain := make(chan string, 1)
domain := make(chan bool, 1)
_ = Bus().SubscribeOnce(CachingResultCacheHit, func(d string) {
domain <- d
domain <- true
})
g.Expect(sut.Resolve(newRequest("example.com.", A))).
@ -246,7 +233,7 @@ var _ = Describe("CachingResolver", func() {
// still one call to upstream
g.Expect(m.Calls).Should(HaveLen(1))
g.Expect(domain).Should(Receive(Equal("example.com")))
g.Expect(domain).Should(Receive(Equal(true)))
}, "1s").Should(Succeed())
})
})

View File

@ -41,7 +41,9 @@ func NewClientNamesResolver(
configurable: withConfig(&cfg),
typed: withType("client_names"),
cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval[[]string](time.Hour)),
cache: expirationcache.NewCache[[]string](expirationcache.Options{
CleanupInterval: time.Hour,
}),
externalResolver: r,
}