mirror of https://github.com/0xERR0R/blocky.git
wip
This commit is contained in:
parent
86071445e5
commit
43e59a48e3
|
@ -2,8 +2,6 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/cache/expirationcache"
|
||||
|
@ -17,7 +15,11 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const defaultCachingCleanUpInterval = 5 * time.Second
|
||||
const (
|
||||
defaultCachingCleanUpInterval = 5 * time.Second
|
||||
// noCacheTTL indicates that a response should not be cached
|
||||
noCacheTTL = time.Duration(-1)
|
||||
)
|
||||
|
||||
// CachingResolver caches answers from dns queries with their TTL time,
|
||||
// to avoid external resolver calls for recurrent queries
|
||||
|
@ -106,19 +108,36 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
|
|||
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
|
||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||
ctx, logger := r.log(ctx)
|
||||
logger = logger.WithField("domain", util.Obfuscate(domainName))
|
||||
|
||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||
logger.Debugf("prefetching %s", qType)
|
||||
|
||||
req := newRequest(dns.Fqdn(domainName), qType)
|
||||
|
||||
response, err := r.next.Resolve(ctx, req)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("domain", domainName).Warn("cache prefetch failed")
|
||||
logger.WithError(err).Warn("cache prefetch failed")
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
return r.transformAndPublish(ctx, cacheKey, response, true)
|
||||
cacheCopy, ttl := r.createCacheEntry(logger, response.Res)
|
||||
if cacheCopy == nil || !cacheableTTL(ttl) {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
packed, err := cacheCopy.Pack()
|
||||
if err != nil {
|
||||
logger.WithError(err).WithError(err).Warn("response packing failed")
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
if r.redisClient != nil {
|
||||
r.redisClient.PublishCache(cacheKey, cacheCopy)
|
||||
}
|
||||
|
||||
return &packed, ttl
|
||||
}
|
||||
|
||||
func (r *CachingResolver) redisSubscriber(ctx context.Context) {
|
||||
|
@ -128,8 +147,13 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) {
|
|||
select {
|
||||
case rc := <-r.redisClient.CacheChannel:
|
||||
if rc != nil {
|
||||
logger.Debug("Received key from redis: ", rc.Key)
|
||||
r.putInCache(ctx, rc.Key, rc.Response, false)
|
||||
_, domain := util.ExtractCacheKey(rc.Key)
|
||||
|
||||
dlogger := logger.WithField("domain", util.Obfuscate(domain))
|
||||
|
||||
dlogger.Debug("received from redis")
|
||||
|
||||
r.putInCache(dlogger, rc.Key, rc.Response)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
|
@ -161,62 +185,56 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
|
|||
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
|
||||
logger := logger.WithField("domain", util.Obfuscate(domain))
|
||||
|
||||
val, ttl := r.getFromCache(logger, cacheKey)
|
||||
cacheEntry := r.getFromCache(logger, cacheKey)
|
||||
|
||||
if val != nil {
|
||||
if cacheEntry != nil {
|
||||
logger.Debug("domain is cached")
|
||||
|
||||
val.SetRcode(request.Req, val.Rcode)
|
||||
cacheEntry.SetRcode(request.Req, cacheEntry.Rcode)
|
||||
|
||||
// Adjust TTL
|
||||
setTTLInCachedResponse(val, ttl)
|
||||
|
||||
if val.Rcode == dns.RcodeSuccess {
|
||||
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
|
||||
if cacheEntry.Rcode == dns.RcodeSuccess {
|
||||
return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
|
||||
}
|
||||
|
||||
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
|
||||
return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
|
||||
}
|
||||
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
|
||||
response, err = r.next.Resolve(ctx, request)
|
||||
|
||||
response, err = r.next.Resolve(ctx, request)
|
||||
if err == nil {
|
||||
r.putInCache(ctx, cacheKey, response, true)
|
||||
ttl := r.modifyResponseTTL(response.Res)
|
||||
if cacheableTTL(ttl) {
|
||||
cacheCopy := r.putInCache(logger, cacheKey, response)
|
||||
if cacheCopy != nil && r.redisClient != nil {
|
||||
r.redisClient.PublishCache(cacheKey, cacheCopy)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) {
|
||||
val, ttl := r.resultCache.Get(key)
|
||||
if val == nil {
|
||||
return nil, 0
|
||||
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) *dns.Msg {
|
||||
raw, ttl := r.resultCache.Get(key)
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
res := new(dns.Msg)
|
||||
|
||||
err := res.Unpack(*val)
|
||||
err := res.Unpack(*raw)
|
||||
if err != nil {
|
||||
logger.Error("can't unpack cached entry. Cache malformed?", err)
|
||||
|
||||
return nil, 0
|
||||
return nil
|
||||
}
|
||||
|
||||
return res, ttl
|
||||
}
|
||||
// Adjust TTL
|
||||
util.AdjustAnswerTTL(res, uint32(ttl.Seconds()))
|
||||
|
||||
func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
|
||||
minTTL := uint32(math.MaxInt32)
|
||||
// find smallest TTL first
|
||||
for _, rr := range resp.Answer {
|
||||
minTTL = min(minTTL, rr.Header().Ttl)
|
||||
}
|
||||
|
||||
for _, rr := range resp.Answer {
|
||||
rr.Header().Ttl = rr.Header().Ttl - minTTL + uint32(ttl.Seconds())
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// isRequestCacheable returns true if the request should be cached
|
||||
|
@ -232,99 +250,61 @@ func isRequestCacheable(request *model.Request) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// isResponseCacheable returns true if the response is not truncated and its CD flag isn't set.
|
||||
func isResponseCacheable(msg *dns.Msg) bool {
|
||||
// we don't cache truncated responses and responses with CD flag
|
||||
return !msg.Truncated && !msg.CheckingDisabled
|
||||
}
|
||||
|
||||
// transformAndPublish transforms the response to a byte array and publishes it to redis if publish is true
|
||||
// and redis is enabled. Returns the byte array and the TTL of the response
|
||||
func (r *CachingResolver) transformAndPublish(ctx context.Context, cacheKey string,
|
||||
response *model.Response, publish bool,
|
||||
) (*[]byte, time.Duration) {
|
||||
if response.Res.Rcode == dns.RcodeSuccess && !isResponseCacheable(response.Res) {
|
||||
return nil, 0
|
||||
func (r *CachingResolver) putInCache(logger *logrus.Entry, cacheKey string, response *model.Response) *dns.Msg {
|
||||
cacheCopy, ttl := r.createCacheEntry(logger, response.Res)
|
||||
if cacheCopy == nil || !cacheableTTL(ttl) {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, domainName := util.ExtractCacheKey(cacheKey)
|
||||
|
||||
_, logger := r.log(ctx)
|
||||
|
||||
respCopy := response.Res.Copy()
|
||||
|
||||
// don't cache any EDNS OPT records
|
||||
util.RemoveEdns0Record(respCopy)
|
||||
|
||||
packed, err := respCopy.Pack()
|
||||
packed, err := cacheCopy.Pack()
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("domain", domainName).Warn("cache prefetch failed")
|
||||
logger.WithError(err).Warn("response packing failed")
|
||||
|
||||
return nil, 0
|
||||
return nil
|
||||
}
|
||||
|
||||
ttl := time.Duration(0)
|
||||
r.resultCache.Put(cacheKey, &packed, ttl)
|
||||
|
||||
if response.Res.Rcode == dns.RcodeSuccess {
|
||||
ttl = r.adjustTTLs(response.Res.Answer)
|
||||
} else if response.Res.Rcode == dns.RcodeNameError {
|
||||
if r.cfg.CacheTimeNegative.IsAboveZero() {
|
||||
ttl = r.cfg.CacheTimeNegative.ToDuration()
|
||||
}
|
||||
}
|
||||
|
||||
if publish && r.redisClient != nil {
|
||||
res := *respCopy
|
||||
for _, rr := range res.Answer {
|
||||
rr.Header().Ttl = uint32(ttl.Seconds())
|
||||
}
|
||||
|
||||
r.redisClient.PublishCache(cacheKey, &res)
|
||||
}
|
||||
|
||||
return &packed, ttl
|
||||
return cacheCopy
|
||||
}
|
||||
|
||||
func (r *CachingResolver) putInCache(
|
||||
ctx context.Context, cacheKey string, response *model.Response, publish bool,
|
||||
) {
|
||||
res, ttl := r.transformAndPublish(ctx, cacheKey, response, publish)
|
||||
if res != nil {
|
||||
r.resultCache.Put(cacheKey, res, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
// adjustTTLs calculates and returns the min TTL (considers also the min and max cache time)
|
||||
// for all records from answer or a negative cache time for empty answer
|
||||
// adjust the TTL in the answer header accordingly
|
||||
func (r *CachingResolver) adjustTTLs(answer []dns.RR) (ttl time.Duration) {
|
||||
minTTL := uint32(math.MaxInt32)
|
||||
|
||||
if len(answer) == 0 {
|
||||
func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) time.Duration {
|
||||
// if response is empty or negative, return negative cache time from config
|
||||
if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError {
|
||||
return r.cfg.CacheTimeNegative.ToDuration()
|
||||
}
|
||||
|
||||
for _, a := range answer {
|
||||
// if TTL < mitTTL -> adjust the value, set minTTL
|
||||
if r.cfg.MinCachingTime.IsAboveZero() {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() {
|
||||
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32())
|
||||
}
|
||||
}
|
||||
|
||||
if r.cfg.MaxCachingTime.IsAboveZero() {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() {
|
||||
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32())
|
||||
}
|
||||
}
|
||||
|
||||
headerTTL := atomic.LoadUint32(&a.Header().Ttl)
|
||||
if minTTL > headerTTL {
|
||||
minTTL = headerTTL
|
||||
}
|
||||
// if response is truncated or CD flag is set, return noCacheTTL since we don't cache these responses
|
||||
if response.Truncated || response.CheckingDisabled {
|
||||
return noCacheTTL
|
||||
}
|
||||
|
||||
return time.Duration(minTTL) * time.Second
|
||||
// if response is not successful, return noCacheTTL since we don't cache these responses
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
return noCacheTTL
|
||||
}
|
||||
|
||||
// adjust TTLs of all answers to match the configured min and max caching times
|
||||
util.SetAnswerMinMaxTTL(response, r.cfg.MinCachingTime.SecondsU32(), r.cfg.MaxCachingTime.SecondsU32())
|
||||
|
||||
return time.Duration(util.GetAnswerMinTTL(response)) * time.Second
|
||||
}
|
||||
|
||||
func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg,
|
||||
) (*dns.Msg, time.Duration) {
|
||||
response := input.Copy()
|
||||
|
||||
ttl := r.modifyResponseTTL(response)
|
||||
if !cacheableTTL(ttl) {
|
||||
logger.Debug("response is not cacheable")
|
||||
|
||||
return nil, noCacheTTL
|
||||
}
|
||||
|
||||
// don't cache any EDNS OPT records
|
||||
util.RemoveEdns0Record(response)
|
||||
|
||||
return response, ttl
|
||||
}
|
||||
|
||||
func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) {
|
||||
|
@ -339,3 +319,7 @@ func (r *CachingResolver) FlushCaches(ctx context.Context) {
|
|||
logger.Debug("flush caches")
|
||||
r.resultCache.Clear()
|
||||
}
|
||||
|
||||
func cacheableTTL(ttl time.Duration) bool {
|
||||
return ttl > 0
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to
|
||||
// the minimum TTL.
|
||||
func SetAnswerMinTTL(msg *dns.Msg, minTTL uint32) {
|
||||
for _, answer := range msg.Answer {
|
||||
if atomic.LoadUint32(&answer.Header().Ttl) < minTTL {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, minTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL
|
||||
// to the maximum TTL.
|
||||
func SetAnswerMaxTTL(msg *dns.Msg, maxTTL uint32) {
|
||||
for _, answer := range msg.Answer {
|
||||
if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL && maxTTL != 0 {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, maxTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL
|
||||
// to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL.
|
||||
func SetAnswerMinMaxTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32) {
|
||||
for _, answer := range msg.Answer {
|
||||
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
|
||||
if headerTTL < minTTL {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, minTTL)
|
||||
} else if headerTTL > maxTTL && maxTTL != 0 {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, maxTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetMinAnswerTTL returns the lowest TTL of all answers in the message.
|
||||
func GetAnswerMinTTL(msg *dns.Msg) uint32 {
|
||||
var minTTL atomic.Uint32
|
||||
// initialize minTTL with the maximum value of uint32
|
||||
minTTL.Store(math.MaxUint32)
|
||||
|
||||
for _, answer := range msg.Answer {
|
||||
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
|
||||
if headerTTL < minTTL.Load() {
|
||||
minTTL.Store(headerTTL)
|
||||
}
|
||||
}
|
||||
|
||||
return minTTL.Load()
|
||||
}
|
||||
|
||||
// AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL
|
||||
// and the answer's TTL plus the specified adjustment.
|
||||
func AdjustAnswerTTL(msg *dns.Msg, adjustment uint32) {
|
||||
minTTL := GetAnswerMinTTL(msg)
|
||||
|
||||
for _, answer := range msg.Answer {
|
||||
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
|
||||
atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustment)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue