From 2a5e443037c1bad450a72f0407f655ac294425ba Mon Sep 17 00:00:00 2001 From: Dimitri Herzog Date: Wed, 1 Mar 2023 22:59:33 +0100 Subject: [PATCH] WIP: quick&dirty expiration cache implementation --- .../expirationcache/redis_expiration_cache.go | 83 +++++++++++++++++++ resolver/caching_resolver.go | 39 ++++----- util/common.go | 8 +- 3 files changed, 102 insertions(+), 28 deletions(-) create mode 100644 cache/expirationcache/redis_expiration_cache.go diff --git a/cache/expirationcache/redis_expiration_cache.go b/cache/expirationcache/redis_expiration_cache.go new file mode 100644 index 00000000..8e1c4f07 --- /dev/null +++ b/cache/expirationcache/redis_expiration_cache.go @@ -0,0 +1,83 @@ +package expirationcache + +import ( + "context" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/miekg/dns" + "strconv" + "time" +) + +type RedisCache struct { + rdb *redis.Client + name string +} + +func (r *RedisCache) Put(key string, val interface{}, expiration time.Duration) { + switch v := val.(type) { + case dns.Msg: + b, err := v.Pack() + if err != nil { + panic(err) + } + err = r.rdb.Set(context.Background(), r.name+":"+key, b, expiration).Err() + if err != nil { + panic(err) + } + case int: + err := r.rdb.Set(context.Background(), r.name+":"+key, v, expiration).Err() + if err != nil { + panic(err) + } + default: + fmt.Println("type unknown") + } + +} + +func (r *RedisCache) Get(key string) (val interface{}, expiration time.Duration) { + bytesVal, err := r.rdb.Get(context.Background(), r.name+":"+key).Bytes() + if err == redis.Nil { + return nil, 0 + } + if err != nil { + panic(err) + } + + exp := r.rdb.TTL(context.Background(), r.name+":"+key).Val() + + if len(bytesVal) <= 2 { + code, err := strconv.Atoi(string(bytesVal)) + if err != nil { + panic(err) + } + return code, exp + } + + msg := new(dns.Msg) + err = msg.Unpack(bytesVal) + if err != nil { + panic(err) + } + + return msg, exp +} + +func (r *RedisCache) TotalCount() int { + //TODO implement me + return 0 +} + +func (r *RedisCache) Clear() { + //TODO implement me + +} + +func NewRedisCache(rdb *redis.Client, name string) *RedisCache { + + return &RedisCache{ + rdb: rdb, + name: name, + } +} diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 7575b13d..0fd7965a 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -15,6 +15,8 @@ import ( "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/util" + rr "github.com/go-redis/redis/v8" + "github.com/miekg/dns" "github.com/sirupsen/logrus" ) @@ -35,8 +37,8 @@ type CachingResolver struct { redisEnabled bool } -// cacheValue includes query answer and prefetch flag -type cacheValue struct { +// CacheValue includes query answer and prefetch flag +type CacheValue struct { answer []dns.RR prefetch bool } @@ -62,9 +64,6 @@ func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingR } func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { - cleanupOption := expirationcache.WithCleanUpInterval(defaultCachingCleanUpInterval) - maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount)) - if cfg.Prefetching { c.prefetchExpires = time.Duration(cfg.PrefetchExpires) @@ -74,15 +73,13 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { expirationcache.WithCleanUpInterval(time.Minute), expirationcache.WithMaxSize(uint(cfg.PrefetchMaxItemsCount)), ) - - c.resultCache = expirationcache.NewCache( - cleanupOption, - maxSizeOption, - expirationcache.WithOnExpiredFn(c.onExpired), - ) - } else { - c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption) } + + c.resultCache = expirationcache.NewRedisCache(rr.NewClient(&rr.Options{ + Addr: "localhost:6379", + Password: "", // no password set + DB: 0, // use default DB + }), "query_cache") } func setupRedisCacheSubscriber(c *CachingResolver) { @@ -124,7 +121,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time. if response.Res.Rcode == dns.RcodeSuccess { evt.Bus().Publish(evt.CachingDomainPrefetched, domainName) - return cacheValue{response.Res.Answer, true}, r.adjustTTLs(response.Res.Answer) + return CacheValue{response.Res.Answer, true}, r.adjustTTLs(response.Res.Answer) } } else { util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err) @@ -187,15 +184,15 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo evt.Bus().Publish(evt.CachingResultCacheHit, domain) - v, ok := val.(cacheValue) + v, ok := val.(*dns.Msg) if ok { - if v.prefetch { - // Hit from prefetch cache - evt.Bus().Publish(evt.CachingPrefetchCacheHit, domain) - } + //if v.prefetch { + // // Hit from prefetch cache + // evt.Bus().Publish(evt.CachingPrefetchCacheHit, domain) + //} // Answer from successful request - for _, rr := range v.answer { + for _, rr := range v.Answer { // make copy here since entries in cache can be modified by other goroutines (e.g. redis cache) cp := dns.Copy(rr) cp.Header().Ttl = uint32(ttl.Seconds()) @@ -245,7 +242,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, if response.Res.Rcode == dns.RcodeSuccess { // put value into cache - r.resultCache.Put(cacheKey, cacheValue{answer, prefetch}, r.adjustTTLs(answer)) + r.resultCache.Put(cacheKey, *response.Res, r.adjustTTLs(answer)) } else if response.Res.Rcode == dns.RcodeNameError { if r.cacheTimeNegative > 0 { // put return code if NXDOMAIN diff --git a/util/common.go b/util/common.go index bba5bc4c..0dd29248 100644 --- a/util/common.go +++ b/util/common.go @@ -165,13 +165,7 @@ func FatalOnError(message string, err error) { // GenerateCacheKey return cacheKey by query type/domain func GenerateCacheKey(qType dns.Type, qName string) string { - const qTypeLength = 2 - b := make([]byte, qTypeLength+len(qName)) - - binary.BigEndian.PutUint16(b, uint16(qType)) - copy(b[2:], strings.ToLower(qName)) - - return string(b) + return dns.TypeToString[uint16(qType)] + ":" + qName } // ExtractCacheKey return query type/domain from cacheKey