blocky/resolver/caching_resolver.go

238 lines
6.8 KiB
Go
Raw Normal View History

2020-01-12 18:23:35 +01:00
package resolver
import (
"fmt"
"time"
2021-08-25 22:06:34 +02:00
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/util"
2020-01-12 18:23:35 +01:00
"github.com/0xERR0R/go-cache"
2020-01-12 18:23:35 +01:00
"github.com/miekg/dns"
2021-01-16 22:24:05 +01:00
"github.com/sirupsen/logrus"
2020-01-12 18:23:35 +01:00
)
2021-02-26 21:44:53 +01:00
// CachingResolver caches answers from dns queries with their TTL time,
// to avoid external resolver calls for recurrent queries
2020-01-12 18:23:35 +01:00
type CachingResolver struct {
NextResolver
2021-01-19 21:52:24 +01:00
minCacheTimeSec, maxCacheTimeSec int
resultCache *cache.Cache
prefetchExpires time.Duration
prefetchThreshold int
2021-01-19 21:52:24 +01:00
prefetchingNameCache *cache.Cache
2020-01-12 18:23:35 +01:00
}
// cacheValue is what the result cache stores for a successful query.
type cacheValue struct {
	// answer holds the resource records of the cached response.
	answer []dns.RR
	// prefetch is true when the entry was filled by a prefetch rather than by
	// a client query.
	prefetch bool
}
// Default timings and thresholds used by the caching resolver.
const (
	// prefetchingNameCountThreshold is the default number of queries a cache
	// key must exceed before its evicted entry is prefetched.
	prefetchingNameCountThreshold = 5

	// prefetchingNameCacheExpiration is the default lifetime of a per-key query
	// counter, used when no PrefetchExpires value is configured.
	prefetchingNameCacheExpiration = time.Hour * 2

	// cacheTimeNegative is how long NXDOMAIN answers stay in the result cache.
	cacheTimeNegative = time.Minute * 30
)
2020-01-12 18:23:35 +01:00
// NewCachingResolver creates a new resolver instance
func NewCachingResolver(cfg config.CachingConfig) ChainedResolver {
2021-01-16 22:24:05 +01:00
c := &CachingResolver{
minCacheTimeSec: 60 * cfg.MinCachingTime,
maxCacheTimeSec: 60 * cfg.MaxCachingTime,
resultCache: createQueryResultCache(&cfg),
2021-01-16 22:24:05 +01:00
}
if cfg.Prefetching {
configurePrefetching(c, &cfg)
2021-01-16 22:24:05 +01:00
}
return c
}
func createQueryResultCache(cfg *config.CachingConfig) *cache.Cache {
return cache.NewWithLRU(15*time.Minute, 15*time.Second, cfg.MaxItemsCount)
2021-01-16 22:24:05 +01:00
}
func configurePrefetching(c *CachingResolver, cfg *config.CachingConfig) {
c.prefetchExpires = prefetchingNameCacheExpiration
if cfg.PrefetchExpires > 0 {
c.prefetchExpires = time.Duration(cfg.PrefetchExpires) * time.Minute
2021-01-16 22:24:05 +01:00
}
c.prefetchThreshold = prefetchingNameCountThreshold
if cfg.PrefetchThreshold > 0 {
c.prefetchThreshold = cfg.PrefetchThreshold
}
c.prefetchingNameCache = cache.NewWithLRU(c.prefetchExpires, time.Minute, cfg.PrefetchMaxItemsCount)
c.resultCache.OnEvicted(func(key string, i interface{}) {
c.onEvicted(key)
})
2021-01-16 22:24:05 +01:00
}
// onEvicted is called if a DNS response in the cache is expired and was removed from cache
func (r *CachingResolver) onEvicted(cacheKey string) {
qType, domainName := util.ExtractCacheKey(cacheKey)
2021-01-16 22:24:05 +01:00
logger := logger("caching_resolver")
cnt, found := r.prefetchingNameCache.Get(cacheKey)
2021-01-16 22:24:05 +01:00
// check if domain was queried > threshold in the time window
if found && cnt.(int) > r.prefetchThreshold {
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), dns.TypeToString[qType])
2021-01-16 22:24:05 +01:00
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
response, err := r.next.Resolve(req)
if err == nil {
r.putInCache(cacheKey, response, true)
2021-01-16 22:24:05 +01:00
2021-01-19 21:52:24 +01:00
evt.Bus().Publish(evt.CachingDomainPrefetched, domainName)
2021-01-16 22:24:05 +01:00
}
2020-11-18 22:31:05 +01:00
2021-01-19 21:52:24 +01:00
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
}
2021-01-16 22:24:05 +01:00
}
// Configuration returns a current resolver configuration
2020-01-12 18:23:35 +01:00
func (r *CachingResolver) Configuration() (result []string) {
if r.maxCacheTimeSec < 0 {
result = []string{"deactivated"}
return
}
result = append(result, fmt.Sprintf("minCacheTimeInSec = %d", r.minCacheTimeSec))
result = append(result, fmt.Sprintf("maxCacheTimeSec = %d", r.maxCacheTimeSec))
2021-01-16 22:24:05 +01:00
result = append(result, fmt.Sprintf("prefetching = %t", r.prefetchingNameCache != nil))
if r.prefetchingNameCache != nil {
result = append(result, fmt.Sprintf("prefetchExpires = %s", r.prefetchExpires))
result = append(result, fmt.Sprintf("prefetchThreshold = %d", r.prefetchThreshold))
}
result = append(result, fmt.Sprintf("cache items count = %d", r.resultCache.ItemCount()))
2020-01-12 18:23:35 +01:00
return
}
// Resolve checks if the current query result is already in the cache and returns it
2021-02-26 21:44:53 +01:00
// or delegates to the next resolver
//nolint:gocognit,funlen
2020-01-12 18:23:35 +01:00
func (r *CachingResolver) Resolve(request *Request) (response *Response, err error) {
logger := withPrefix(request.Log, "caching_resolver")
if r.maxCacheTimeSec < 0 {
logger.Debug("skip cache")
return r.next.Resolve(request)
}
2020-01-12 18:23:35 +01:00
resp := new(dns.Msg)
resp.SetReply(request.Req)
for _, question := range request.Req.Question {
domain := util.ExtractDomain(question)
cacheKey := util.GenerateCacheKey(question.Qtype, domain)
logger := logger.WithField("domain", util.Obfuscate(domain))
2020-01-12 18:23:35 +01:00
r.trackQueryDomainNameCount(domain, cacheKey, logger)
2021-01-16 22:24:05 +01:00
val, expiresAt, found := r.resultCache.GetWithExpiration(cacheKey)
2020-01-12 18:23:35 +01:00
if found {
logger.Debug("domain is cached")
2020-01-12 18:23:35 +01:00
evt.Bus().Publish(evt.CachingResultCacheHit, domain)
2020-11-18 22:31:05 +01:00
// calculate remaining TTL
remainingTTL := uint32(time.Until(expiresAt).Seconds())
2020-01-12 18:23:35 +01:00
v, ok := val.(cacheValue)
if ok {
if v.prefetch {
// Hit from prefetch cache
evt.Bus().Publish(evt.CachingPrefetchCacheHit, domain)
}
// Answer from successful request
resp.Answer = v.answer
for _, rr := range resp.Answer {
rr.Header().Ttl = remainingTTL
2020-01-12 18:23:35 +01:00
}
return &Response{Res: resp, RType: CACHED, Reason: "CACHED"}, nil
2020-01-12 18:23:35 +01:00
}
// Answer with response code != OK
resp.Rcode = val.(int)
2020-01-12 18:23:35 +01:00
return &Response{Res: resp, RType: CACHED, Reason: "CACHED NEGATIVE"}, nil
}
2020-11-18 22:31:05 +01:00
evt.Bus().Publish(evt.CachingResultCacheMiss, domain)
2020-01-12 18:23:35 +01:00
logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
response, err = r.next.Resolve(request)
if err == nil {
r.putInCache(cacheKey, response, false)
2020-01-12 18:23:35 +01:00
}
}
return response, err
}
func (r *CachingResolver) trackQueryDomainNameCount(domain string, cacheKey string, logger *logrus.Entry) {
2021-01-16 22:24:05 +01:00
if r.prefetchingNameCache != nil {
var domainCount int
if x, found := r.prefetchingNameCache.Get(cacheKey); found {
2021-01-16 22:24:05 +01:00
domainCount = x.(int)
}
domainCount++
r.prefetchingNameCache.SetDefault(cacheKey, domainCount)
2021-01-16 22:24:05 +01:00
logger.Debugf("domain '%s' was requested %d times, "+
"total cache size: %d", util.Obfuscate(domain), domainCount, r.prefetchingNameCache.ItemCount())
2021-01-19 21:52:24 +01:00
evt.Bus().Publish(evt.CachingDomainsToPrefetchCountChanged, r.prefetchingNameCache.ItemCount())
2021-01-16 22:24:05 +01:00
}
}
func (r *CachingResolver) putInCache(cacheKey string, response *Response, prefetch bool) {
answer := response.Res.Answer
if response.Res.Rcode == dns.RcodeSuccess {
// put value into cache
r.resultCache.Set(cacheKey, cacheValue{answer, prefetch}, time.Duration(r.adjustTTLs(answer))*time.Second)
} else if response.Res.Rcode == dns.RcodeNameError {
// put return code if NXDOMAIN
r.resultCache.Set(cacheKey, response.Res.Rcode, cacheTimeNegative)
}
2021-01-19 21:52:24 +01:00
evt.Bus().Publish(evt.CachingResultCacheChanged, r.resultCache.ItemCount())
}
func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL uint32) {
2020-01-17 21:53:15 +01:00
for _, a := range answer {
// if TTL < mitTTL -> adjust the value, set minTTL
if r.minCacheTimeSec > 0 {
if a.Header().Ttl < uint32(r.minCacheTimeSec) {
a.Header().Ttl = uint32(r.minCacheTimeSec)
}
}
if r.maxCacheTimeSec > 0 {
if a.Header().Ttl > uint32(r.maxCacheTimeSec) {
a.Header().Ttl = uint32(r.maxCacheTimeSec)
}
2020-01-17 21:53:15 +01:00
}
if maxTTL < a.Header().Ttl {
maxTTL = a.Header().Ttl
}
}
return
}