mirror of https://github.com/0xERR0R/blocky.git
280 lines
7.7 KiB
Go
280 lines
7.7 KiB
Go
package resolver
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/0xERR0R/blocky/cache/expirationcache"
|
|
"github.com/0xERR0R/blocky/config"
|
|
"github.com/0xERR0R/blocky/evt"
|
|
"github.com/0xERR0R/blocky/log"
|
|
"github.com/0xERR0R/blocky/model"
|
|
"github.com/0xERR0R/blocky/redis"
|
|
"github.com/0xERR0R/blocky/util"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const defaultCachingCleanUpInterval = 5 * time.Second
|
|
|
|
// CachingResolver caches answers from dns queries with their TTL time,
// to avoid external resolver calls for recurrent queries
type CachingResolver struct {
	configurable[*config.CachingConfig]
	NextResolver
	typed

	emitMetricEvents bool // disabled by Bootstrap

	// resultCache stores DNS responses keyed by query type + domain
	// (see util.GenerateCacheKey) with a per-entry TTL.
	resultCache expirationcache.ExpiringCache[dns.Msg]

	// redisClient, when non-nil, is used to share cache entries between
	// instances; incoming entries are consumed by setupRedisCacheSubscriber.
	redisClient *redis.Client
}
|
|
|
|
// NewCachingResolver creates a new resolver instance
|
|
func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver {
|
|
return newCachingResolver(cfg, redis, true)
|
|
}
|
|
|
|
func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetricEvents bool) *CachingResolver {
|
|
c := &CachingResolver{
|
|
configurable: withConfig(&cfg),
|
|
typed: withType("caching"),
|
|
|
|
redisClient: redis,
|
|
emitMetricEvents: emitMetricEvents,
|
|
}
|
|
|
|
configureCaches(c, &cfg)
|
|
|
|
if c.redisClient != nil {
|
|
setupRedisCacheSubscriber(c)
|
|
c.redisClient.GetRedisCache()
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
|
options := expirationcache.Options{
|
|
CleanupInterval: defaultCachingCleanUpInterval,
|
|
MaxSize: uint(cfg.MaxItemsCount),
|
|
OnCacheHitFn: func(key string) {
|
|
c.publishMetricsIfEnabled(evt.CachingResultCacheHit, key)
|
|
},
|
|
OnCacheMissFn: func(key string) {
|
|
c.publishMetricsIfEnabled(evt.CachingResultCacheMiss, key)
|
|
},
|
|
OnAfterPutFn: func(newSize int) {
|
|
c.publishMetricsIfEnabled(evt.CachingResultCacheChanged, newSize)
|
|
},
|
|
}
|
|
|
|
if cfg.Prefetching {
|
|
prefetchingOptions := expirationcache.PrefetchingOptions[dns.Msg]{
|
|
Options: options,
|
|
PrefetchExpires: time.Duration(cfg.PrefetchExpires),
|
|
PrefetchThreshold: cfg.PrefetchThreshold,
|
|
PrefetchMaxItemsCount: cfg.PrefetchMaxItemsCount,
|
|
ReloadFn: c.reloadCacheEntry,
|
|
OnPrefetchAfterPut: func(newSize int) {
|
|
c.publishMetricsIfEnabled(evt.CachingDomainsToPrefetchCountChanged, newSize)
|
|
},
|
|
OnPrefetchEntryReloaded: func(key string) {
|
|
c.publishMetricsIfEnabled(evt.CachingDomainPrefetched, key)
|
|
},
|
|
OnPrefetchCacheHit: func(key string) {
|
|
c.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, key)
|
|
},
|
|
}
|
|
|
|
c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions)
|
|
} else {
|
|
c.resultCache = expirationcache.NewCache[dns.Msg](options)
|
|
}
|
|
}
|
|
|
|
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Duration) {
|
|
qType, domainName := util.ExtractCacheKey(cacheKey)
|
|
logger := r.log()
|
|
|
|
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
|
|
|
req := newRequest(fmt.Sprintf("%s.", domainName), qType, logger)
|
|
response, err := r.next.Resolve(req)
|
|
|
|
if err == nil {
|
|
if response.Res.Rcode == dns.RcodeSuccess {
|
|
return response.Res, r.adjustTTLs(response.Res.Answer)
|
|
}
|
|
} else {
|
|
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
|
}
|
|
|
|
return nil, 0
|
|
}
|
|
|
|
func setupRedisCacheSubscriber(c *CachingResolver) {
|
|
go func() {
|
|
for rc := range c.redisClient.CacheChannel {
|
|
if rc != nil {
|
|
c.log().Debug("Received key from redis: ", rc.Key)
|
|
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
|
c.putInCache(rc.Key, rc.Response, ttl, false)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// LogConfig implements `config.Configurable`.
|
|
func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
|
|
r.cfg.LogConfig(logger)
|
|
|
|
logger.Infof("cache entries = %d", r.resultCache.TotalCount())
|
|
}
|
|
|
|
// Resolve checks if the current query result is already in the cache and returns it
// or delegates to the next resolver
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
	logger := log.WithPrefix(request.Log, "caching_resolver")

	// a negative MaxCachingTime disables caching entirely
	if r.cfg.MaxCachingTime < 0 {
		logger.Debug("skip cache")

		return r.next.Resolve(request)
	}

	// NOTE(review): on a cache hit this returns immediately, so for a
	// (rare) multi-question request only the first question's hit is served.
	for _, question := range request.Req.Question {
		domain := util.ExtractDomain(question)
		cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
		// deliberate shadow: scope the domain field to this iteration
		logger := logger.WithField("domain", util.Obfuscate(domain))

		val, ttl := r.resultCache.Get(cacheKey)

		if val != nil {
			logger.Debug("domain is cached")

			// copy so the cached message is never mutated
			resp := val.Copy()
			resp.SetReply(request.Req)
			// restore the cached Rcode (e.g. NXDOMAIN), which SetReply
			// presumably resets to success — semantics from miekg/dns
			resp.Rcode = val.Rcode

			// Adjust TTL
			setTTLInCachedResponse(resp, ttl)

			if resp.Rcode == dns.RcodeSuccess {
				return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
			}

			return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
		}

		logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
		response, err = r.next.Resolve(request)

		if err == nil {
			// cache the upstream answer for subsequent queries
			cacheTTL := r.adjustTTLs(response.Res.Answer)
			r.putInCache(cacheKey, response, cacheTTL, true)
		}
	}

	return response, err
}
|
|
|
|
func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
|
|
minTTL := uint32(math.MaxInt32)
|
|
// find smallest TTL first
|
|
for _, rr := range resp.Answer {
|
|
minTTL = min(minTTL, rr.Header().Ttl)
|
|
}
|
|
|
|
for _, rr := range resp.Answer {
|
|
rr.Header().Ttl = rr.Header().Ttl - minTTL + uint32(ttl.Seconds())
|
|
}
|
|
}
|
|
|
|
// removes EDNS OPT records from message
|
|
func removeEdns0Extra(msg *dns.Msg) {
|
|
if len(msg.Extra) > 0 {
|
|
extra := make([]dns.RR, 0, len(msg.Extra))
|
|
|
|
for _, rr := range msg.Extra {
|
|
if rr.Header().Rrtype != dns.TypeOPT {
|
|
extra = append(extra, rr)
|
|
}
|
|
}
|
|
|
|
msg.Extra = extra
|
|
}
|
|
}
|
|
|
|
func shouldBeCached(msg *dns.Msg) bool {
|
|
// we don't cache truncated responses and responses with CD flag
|
|
return !msg.Truncated && !msg.CheckingDisabled
|
|
}
|
|
|
|
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
|
|
publish bool,
|
|
) {
|
|
respCopy := response.Res.Copy()
|
|
|
|
// don't cache any EDNS OPT records
|
|
removeEdns0Extra(respCopy)
|
|
|
|
if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) {
|
|
// put value into cache
|
|
r.resultCache.Put(cacheKey, respCopy, ttl)
|
|
} else if response.Res.Rcode == dns.RcodeNameError {
|
|
if r.cfg.CacheTimeNegative.IsAboveZero() {
|
|
// put negative cache if result code is NXDOMAIN
|
|
r.resultCache.Put(cacheKey, respCopy, r.cfg.CacheTimeNegative.ToDuration())
|
|
}
|
|
}
|
|
|
|
if publish && r.redisClient != nil {
|
|
res := *respCopy
|
|
r.redisClient.PublishCache(cacheKey, &res)
|
|
}
|
|
}
|
|
|
|
// adjustTTLs calculates and returns the min TTL (considers also the min and max cache time)
// for all records from answer or a negative cache time for empty answer
// adjust the TTL in the answer header accordingly
//
// NOTE(review): TTL headers are read/written via sync/atomic — presumably
// because prefetch reloads may touch the same records concurrently; confirm
// before changing to plain accesses.
func (r *CachingResolver) adjustTTLs(answer []dns.RR) (ttl time.Duration) {
	minTTL := uint32(math.MaxInt32)

	// no answer records: use the configured negative cache duration
	if len(answer) == 0 {
		return r.cfg.CacheTimeNegative.ToDuration()
	}

	for _, a := range answer {
		// if TTL < minTTL -> adjust the value, set minTTL
		// clamp each record's TTL to the configured minimum
		if r.cfg.MinCachingTime.IsAboveZero() {
			if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() {
				atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32())
			}
		}

		// clamp each record's TTL to the configured maximum
		if r.cfg.MaxCachingTime.IsAboveZero() {
			if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() {
				atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32())
			}
		}

		// track the smallest (clamped) TTL across all records
		headerTTL := atomic.LoadUint32(&a.Header().Ttl)
		if minTTL > headerTTL {
			minTTL = headerTTL
		}
	}

	return time.Duration(minTTL) * time.Second
}
|
|
|
|
func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) {
|
|
if r.emitMetricEvents {
|
|
evt.Bus().Publish(event, val)
|
|
}
|
|
}
|