WIP: quick&dirty expiration cache implementation

This commit is contained in:
Dimitri Herzog 2023-03-01 22:59:33 +01:00
parent a695d38c99
commit 2a5e443037
3 changed files with 102 additions and 28 deletions

View File

@ -0,0 +1,83 @@
package expirationcache
import (
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/miekg/dns"
"strconv"
"time"
)
type RedisCache struct {
rdb *redis.Client
name string
}
func (r *RedisCache) Put(key string, val interface{}, expiration time.Duration) {
switch v := val.(type) {
case dns.Msg:
b, err := v.Pack()
if err != nil {
panic(err)
}
err = r.rdb.Set(context.Background(), r.name+":"+key, b, expiration).Err()
if err != nil {
panic(err)
}
case int:
err := r.rdb.Set(context.Background(), r.name+":"+key, v, expiration).Err()
if err != nil {
panic(err)
}
default:
fmt.Println("type unknown")
}
}
func (r *RedisCache) Get(key string) (val interface{}, expiration time.Duration) {
bytesVal, err := r.rdb.Get(context.Background(), r.name+":"+key).Bytes()
if err == redis.Nil {
return nil, 0
}
if err != nil {
panic(err)
}
exp := r.rdb.TTL(context.Background(), r.name+":"+key).Val()
if len(bytesVal) <= 2 {
code, err := strconv.Atoi(string(bytesVal))
if err != nil {
panic(err)
}
return code, exp
}
msg := new(dns.Msg)
err = msg.Unpack(bytesVal)
if err != nil {
panic(err)
}
return msg, exp
}
func (r *RedisCache) TotalCount() int {
//TODO implement me
return 0
}
func (r *RedisCache) Clear() {
//TODO implement me
}
func NewRedisCache(rdb *redis.Client, name string) *RedisCache {
return &RedisCache{
rdb: rdb,
name: name,
}
}

View File

@ -15,6 +15,8 @@ import (
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/util"
rr "github.com/go-redis/redis/v8"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
@ -35,8 +37,8 @@ type CachingResolver struct {
redisEnabled bool
}
// cacheValue includes query answer and prefetch flag
type cacheValue struct {
// CacheValue includes query answer and prefetch flag
type CacheValue struct {
answer []dns.RR
prefetch bool
}
@ -62,9 +64,6 @@ func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingR
}
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
cleanupOption := expirationcache.WithCleanUpInterval(defaultCachingCleanUpInterval)
maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount))
if cfg.Prefetching {
c.prefetchExpires = time.Duration(cfg.PrefetchExpires)
@ -74,15 +73,13 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
expirationcache.WithCleanUpInterval(time.Minute),
expirationcache.WithMaxSize(uint(cfg.PrefetchMaxItemsCount)),
)
c.resultCache = expirationcache.NewCache(
cleanupOption,
maxSizeOption,
expirationcache.WithOnExpiredFn(c.onExpired),
)
} else {
c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption)
}
c.resultCache = expirationcache.NewRedisCache(rr.NewClient(&rr.Options{
Addr: "localhost:6379",
Password: "", // no password set
DB: 0, // use default DB
}), "query_cache")
}
func setupRedisCacheSubscriber(c *CachingResolver) {
@ -124,7 +121,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
if response.Res.Rcode == dns.RcodeSuccess {
evt.Bus().Publish(evt.CachingDomainPrefetched, domainName)
return cacheValue{response.Res.Answer, true}, r.adjustTTLs(response.Res.Answer)
return CacheValue{response.Res.Answer, true}, r.adjustTTLs(response.Res.Answer)
}
} else {
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
@ -187,15 +184,15 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
evt.Bus().Publish(evt.CachingResultCacheHit, domain)
v, ok := val.(cacheValue)
v, ok := val.(*dns.Msg)
if ok {
if v.prefetch {
// Hit from prefetch cache
evt.Bus().Publish(evt.CachingPrefetchCacheHit, domain)
}
//if v.prefetch {
// // Hit from prefetch cache
// evt.Bus().Publish(evt.CachingPrefetchCacheHit, domain)
//}
// Answer from successful request
for _, rr := range v.answer {
for _, rr := range v.Answer {
// make copy here since entries in cache can be modified by other goroutines (e.g. redis cache)
cp := dns.Copy(rr)
cp.Header().Ttl = uint32(ttl.Seconds())
@ -245,7 +242,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
if response.Res.Rcode == dns.RcodeSuccess {
// put value into cache
r.resultCache.Put(cacheKey, cacheValue{answer, prefetch}, r.adjustTTLs(answer))
r.resultCache.Put(cacheKey, *response.Res, r.adjustTTLs(answer))
} else if response.Res.Rcode == dns.RcodeNameError {
if r.cacheTimeNegative > 0 {
// put return code if NXDOMAIN

View File

@ -165,13 +165,7 @@ func FatalOnError(message string, err error) {
// GenerateCacheKey return cacheKey by query type/domain
func GenerateCacheKey(qType dns.Type, qName string) string {
const qTypeLength = 2
b := make([]byte, qTypeLength+len(qName))
binary.BigEndian.PutUint16(b, uint16(qType))
copy(b[2:], strings.ToLower(qName))
return string(b)
return dns.TypeToString[uint16(qType)] + ":" + qName
}
// ExtractCacheKey return query type/domain from cacheKey