diff --git a/cache/expirationcache/expiration_cache.go b/cache/expirationcache/expiration_cache.go index ab37e313..400308c6 100644 --- a/cache/expirationcache/expiration_cache.go +++ b/cache/expirationcache/expiration_cache.go @@ -103,14 +103,14 @@ func periodicCleanup[T any](ctx context.Context, c *ExpiringLRUCache[T]) { for { select { case <-ticker.C: - c.cleanUp() + c.cleanUp(ctx) case <-ctx.Done(): return } } } -func (e *ExpiringLRUCache[T]) cleanUp() { +func (e *ExpiringLRUCache[T]) cleanUp(ctx context.Context) { var expiredKeys []string // check for expired items and collect expired keys @@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() { var keysToDelete []string for _, key := range expiredKeys { - newVal, newTTL := e.preExpirationFn(context.Background(), key) + newVal, newTTL := e.preExpirationFn(ctx, key) if newVal != nil { e.Put(key, newVal, newTTL) } else { diff --git a/cache/expirationcache/expiration_cache_test.go b/cache/expirationcache/expiration_cache_test.go index 3e3b8b0c..95bb46cb 100644 --- a/cache/expirationcache/expiration_cache_test.go +++ b/cache/expirationcache/expiration_cache_test.go @@ -181,7 +181,7 @@ var _ = Describe("Expiration cache", func() { time.Sleep(2 * time.Millisecond) // trigger cleanUp manually -> onExpiredFn will be executed, because element is expired - cache.cleanUp() + cache.cleanUp(ctx) // wait for expiration val, ttl := cache.Get("key1") diff --git a/redis/redis.go b/redis/redis.go index 7c205f19..82dab765 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -5,13 +5,11 @@ import ( "context" "encoding/json" "fmt" - "math" "strings" "time" "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" - "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/go-redis/redis/v8" "github.com/google/uuid" @@ -29,10 +27,10 @@ const ( messageTypeEnable = 1 ) -// sendBuffer message -type bufferMessage struct { - Key string - Message *dns.Msg +type CacheEntry struct { + TTL uint32 + Key string + Entry []byte } // redis pubsub message @@ -43,12 +41,6 @@ type redisMessage struct { Client []byte `json:"c"` } -// CacheChannel message -type CacheMessage struct { - Key string - Response *model.Response -} - type EnabledMessage struct { State bool `json:"s"` Duration time.Duration `json:"d,omitempty"` @@ -61,8 +53,8 @@ type Client struct { client *redis.Client l *logrus.Entry id []byte - sendBuffer chan *bufferMessage - CacheChannel chan *CacheMessage + sendBuffer chan *CacheEntry + CacheChannel chan *CacheEntry EnabledChannel chan *EnabledMessage } @@ -100,55 +92,63 @@ func New(ctx context.Context, cfg *config.Redis) (*Client, error) { rdb := baseClient.WithContext(ctx) _, err := rdb.Ping(ctx).Result() - if err == nil { - var id []byte - - id, err = uuid.New().MarshalBinary() - if err == nil { - // construct client - res := &Client{ - config: cfg, - client: rdb, - l: log.PrefixedLog("redis"), - id: id, - sendBuffer: make(chan *bufferMessage, chanCap), - CacheChannel: make(chan *CacheMessage, chanCap), - EnabledChannel: make(chan *EnabledMessage, chanCap), - } - - // start channel handling go routine - err = res.startup(ctx) - - return res, err - } + if err != nil { + return nil, err } - return nil, err + var id []byte + + id, err = uuid.New().MarshalBinary() + if err != nil { + return nil, err + } + // construct client + res := &Client{ + config: cfg, + client: rdb, + l: log.PrefixedLog("redis"), + id: id, + sendBuffer: make(chan *CacheEntry, chanCap), + CacheChannel: make(chan *CacheEntry, chanCap), + EnabledChannel: make(chan *EnabledMessage, chanCap), + } + + // start channel handling go routine + err = res.startup(ctx) + if err != nil { + return nil, err + } + + return res, nil } -// PublishCache publish cache to redis async -func (c *Client) PublishCache(key string, message *dns.Msg) { - if len(key) > 0 && message != nil { - c.sendBuffer <- &bufferMessage{ - Key: key, - Message: message, +// PublishCache publish cache entry to redis if key and message are not empty and ttl > 0 +func (c *Client) PublishCache(key string, ttl uint32, message []byte) { + if len(key) > 0 && len(message) > 0 && ttl > 0 { + c.sendBuffer <- &CacheEntry{ + TTL: ttl, + Key: key, + Entry: message, } } } func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) { - binState, sErr := json.Marshal(state) - if sErr == nil { - binMsg, mErr := json.Marshal(redisMessage{ - Type: messageTypeEnable, - Message: binState, - Client: c.id, - }) - - if mErr == nil { - c.client.Publish(ctx, SyncChannelName, binMsg) - } + binState, err := json.Marshal(state) + if err != nil { + return } + + binMsg, err := json.Marshal(redisMessage{ + Type: messageTypeEnable, + Message: binState, + Client: c.id, + }) + if err != nil { + return + } + + c.client.Publish(ctx, SyncChannelName, binMsg) } // GetRedisCache reads the redis cache and publish it to the channel @@ -183,56 +183,53 @@ func (c *Client) startup(ctx context.Context) error { ps := c.client.Subscribe(ctx, SyncChannelName) _, err := ps.Receive(ctx) - if err == nil { - go func() { - for { - select { - // received message from subscription - case msg := <-ps.Channel(): - c.l.Debug("Received message: ", msg) - - if msg != nil && len(msg.Payload) > 0 { - // message is not empty - c.processReceivedMessage(ctx, msg) - } - // publish message from buffer - case s := <-c.sendBuffer: - c.publishMessageFromBuffer(ctx, s) - // context is done - case <-ctx.Done(): - c.client.Close() - - return - } - } - }() + if err != nil { + return err } - return err + go func() { + defer ps.Close() + defer c.client.Close() + + for { + select { + // received message from subscription + case msg := <-ps.Channel(): + c.l.Debug("Received message: ", msg) + + if msg != nil && len(msg.Payload) > 0 { + // message is not empty + c.processReceivedMessage(ctx, msg) + } + // publish message from buffer + case s := <-c.sendBuffer: + c.publishMessageFromBuffer(ctx, s) + // context is done + case <-ctx.Done(): + return + } + } + }() + + return nil } -func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) { - origRes := s.Message - origRes.Compress = true - binRes, pErr := origRes.Pack() - - if pErr == nil { - binMsg, mErr := json.Marshal(redisMessage{ - Key: s.Key, - Type: messageTypeCache, - Message: binRes, - Client: c.id, - }) - - if mErr == nil { - c.client.Publish(ctx, SyncChannelName, binMsg) - } - - c.client.Set(ctx, - prefixKey(s.Key), - binRes, - c.getTTL(origRes)) +// publishMessageFromBuffer publishes a message from the buffer to the redis channel and stores it in the cache +func (c *Client) publishMessageFromBuffer(ctx context.Context, s *CacheEntry) { + psMsg, err := json.Marshal(redisMessage{ + Key: s.Key, + Type: messageTypeCache, + Message: s.Entry, + Client: c.id, + }) + if err == nil { + c.client.Publish(ctx, SyncChannelName, psMsg) } + + c.client.Set(ctx, + prefixKey(s.Key), + s.Entry, + util.ToTTLDuration(s.TTL)) } func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) { @@ -248,26 +245,24 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) if !bytes.Equal(rm.Client, c.id) { switch rm.Type { case messageTypeCache: - var cm *CacheMessage - - cm, err := convertMessage(&rm, 0) + cm, err := convertMessage(0, rm.Key, rm.Message) if err != nil { - c.l.Error("Processing CacheMessage error: ", err) + c.l.Error(err) return } - util.CtxSend(ctx, c.CacheChannel, cm) + util.CtxSend(ctx, c.CacheChannel, &cm) case messageTypeEnable: - var msg EnabledMessage + msg := new(EnabledMessage) - if err := json.Unmarshal(rm.Message, &msg); err != nil { + if err := json.Unmarshal(rm.Message, msg); err != nil { c.l.Error("Processing EnabledMessage error: ", err) return } - util.CtxSend(ctx, c.EnabledChannel, &msg) + util.CtxSend(ctx, c.EnabledChannel, msg) default: c.l.Warn("Unknown message type: ", rm.Type) } @@ -275,69 +270,48 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) } // getResponse returns model.Response for a key -func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) { +func (c *Client) getResponse(ctx context.Context, key string) (*CacheEntry, error) { resp, err := c.client.Get(ctx, key).Result() - if err == nil { - var ttl time.Duration - ttl, err = c.client.TTL(ctx, key).Result() - - if err == nil { - var result *CacheMessage - - result, err = convertMessage(&redisMessage{ - Key: cleanKey(key), - Message: []byte(resp), - }, ttl) - if err != nil { - return nil, fmt.Errorf("conversion error: %w", err) - } - - return result, nil - } + if err != nil { + return nil, err } - return nil, err + ttl, err := c.client.TTL(ctx, key).Result() + if err != nil { + return nil, err + } + + result, err := convertMessage(ttl, cleanKey(key), []byte(resp)) + if err != nil { + return nil, err + } + + return &result, nil } // convertMessage converts redisMessage to CacheMessage -func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, error) { - msg := dns.Msg{} - - err := msg.Unpack(message.Message) - if err == nil { - if ttl > 0 { - for _, a := range msg.Answer { - a.Header().Ttl = uint32(ttl.Seconds()) - } - } - - res := &CacheMessage{ - Key: message.Key, - Response: &model.Response{ - RType: model.ResponseTypeCACHED, - Reason: cacheReason, - Res: &msg, - }, - } - - return res, nil +func convertMessage[T util.TTLInput](ttl T, key string, message []byte) (CacheEntry, error) { + res := CacheEntry{ + TTL: util.ToTTL(ttl), + Key: key, + Entry: message, } - return nil, err -} + packErr := fmt.Errorf("invalid message for key %s", key) + if len(message) == 0 { + return res, packErr + } -// getTTL of dns message or return defaultCacheTime if 0 -func (c *Client) getTTL(dns *dns.Msg) time.Duration { - ttl := uint32(math.MaxInt32) - for _, a := range dns.Answer { - ttl = min(ttl, a.Header().Ttl) + msg := new(dns.Msg) + if err := msg.Unpack(message); err != nil { + return res, packErr } if ttl == 0 { - return defaultCacheTime + res.TTL = util.GetAnswerMinTTL(msg) } - return time.Duration(ttl) * time.Second + return res, nil } // prefixKey with CacheStorePrefix diff --git a/redis/redis_test.go b/redis/redis_test.go index 5c7f530a..96f06a06 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -96,10 +96,12 @@ var _ = Describe("Redis client", func() { By("publish new message with TTL > 0", func() { res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") - Expect(err).Should(Succeed()) - redisClient.PublishCache("example.com", res) + binRes, err := res.Pack() + Expect(err).Should(Succeed()) + + redisClient.PublishCache("example.com", 123, binRes) }) By("Database has one entry with correct TTL", func() { @@ -111,34 +113,6 @@ var _ = Describe("Redis client", func() { Expect(ttl.Seconds()).Should(BeNumerically("~", 123)) }) }) - - It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) { - redisClient, err = New(ctx, redisConfig) - Expect(err).Should(Succeed()) - - By("Database is empty", func() { - Eventually(func() []string { - return redisServer.DB(redisConfig.Database).Keys() - }).Should(BeEmpty()) - }) - - By("publish new message with TTL = 0", func() { - res, err := util.NewMsgWithAnswer("example.com.", 0, dns.Type(dns.TypeA), "123.124.122.123") - - Expect(err).Should(Succeed()) - - redisClient.PublishCache("example.com", res) - }) - - By("Database has one entry with default TTL", func() { - Eventually(func() bool { - return redisServer.DB(redisConfig.Database).Exists(exampleComKey) - }).Should(BeTrue()) - - ttl := redisServer.DB(redisConfig.Database).TTL(exampleComKey) - Expect(ttl.Seconds()).Should(BeNumerically("~", defaultCacheTime.Seconds())) - }) - }) }) When("Redis client publishes 'enabled' message", func() { It("should propagate the message over redis", func(ctx context.Context) { @@ -222,7 +196,7 @@ var _ = Describe("Redis client", func() { rec := redisServer.Publish(SyncChannelName, string(binMsg)) Expect(rec).Should(Equal(1)) - Eventually(func() chan *CacheMessage { + Eventually(func() chan *CacheEntry { return redisClient.CacheChannel }).Should(HaveLen(lenE + 1)) }, SpecTimeout(time.Second*6)) @@ -255,7 +229,7 @@ var _ = Describe("Redis client", func() { return redisClient.EnabledChannel }).Should(HaveLen(lenE)) - Eventually(func() chan *CacheMessage { + Eventually(func() chan *CacheEntry { return redisClient.CacheChannel }).Should(HaveLen(lenC)) }, SpecTimeout(time.Second*6)) @@ -288,7 +262,7 @@ var _ = Describe("Redis client", func() { return redisClient.EnabledChannel }).Should(HaveLen(lenE)) - Eventually(func() chan *CacheMessage { + Eventually(func() chan *CacheEntry { return redisClient.CacheChannel }).Should(HaveLen(lenC)) }, SpecTimeout(time.Second*6)) @@ -312,13 +286,13 @@ var _ = Describe("Redis client", func() { }) By("Put valid data in Redis by publishing the cache entry", func() { - var res *dns.Msg - - res, err = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") - + res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") Expect(err).Should(Succeed()) - redisClient.PublishCache("example.com", res) + binRes, err := res.Pack() + Expect(err).Should(Succeed()) + + redisClient.PublishCache("example.com", 123, binRes) }) By("Database has one entry now", func() { diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index e2a052f0..cf30cabf 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -2,9 +2,6 @@ package resolver import ( "context" - "fmt" - "math" - "sync/atomic" "time" "github.com/0xERR0R/blocky/cache/expirationcache" @@ -18,7 +15,9 @@ import ( "github.com/sirupsen/logrus" ) -const defaultCachingCleanUpInterval = 5 * time.Second +const ( + defaultCachingCleanUpInterval = 5 * time.Second +) // CachingResolver caches answers from dns queries with their TTL time, // to avoid external resolver calls for recurrent queries @@ -107,28 +106,29 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) { qType, domainName := util.ExtractCacheKey(cacheKey) ctx, logger := r.log(ctx) + logger = logger.WithField("domain", util.Obfuscate(domainName)) - logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) + logger.Debugf("prefetching %s", qType) req := newRequest(dns.Fqdn(domainName), qType) + response, err := r.next.Resolve(ctx, req) + if err != nil { + logger.WithError(err).Warn("cache prefetch failed") - if err == nil { - if response.Res.Rcode == dns.RcodeSuccess { - packed, err := response.Res.Pack() - if err != nil { - logger.Error("unable to pack response", err) - - return nil, 0 - } - - return &packed, r.adjustTTLs(response.Res.Answer) - } - } else { - util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err) + return nil, 0 } - return nil, 0 + ttl, res := r.createCacheEntry(logger, response.Res) + if ttl == 0 || len(res) == 0 { + return nil, 0 + } + + if r.redisClient != nil { + r.redisClient.PublishCache(cacheKey, ttl, res) + } + + return &res, util.ToTTLDuration(ttl) } func (r *CachingResolver) redisSubscriber(ctx context.Context) { @@ -138,9 +138,11 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { select { case rc := <-r.redisClient.CacheChannel: if rc != nil { - logger.Debug("Received key from redis: ", rc.Key) - ttl := r.adjustTTLs(rc.Response.Res.Answer) - r.putInCache(ctx, rc.Key, rc.Response, ttl, false) + _, domain := util.ExtractCacheKey(rc.Key) + + logger.WithField("domain", util.Obfuscate(domain)).Debug("received from redis") + + r.resultCache.Put(rc.Key, &rc.Entry, util.ToTTLDuration(rc.TTL)) } case <-ctx.Done(): @@ -172,63 +174,57 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain) logger := logger.WithField("domain", util.Obfuscate(domain)) - val, ttl := r.getFromCache(logger, cacheKey) + cacheEntry := r.getFromCache(logger, cacheKey) - if val != nil { + if cacheEntry != nil { logger.Debug("domain is cached") - val.SetRcode(request.Req, val.Rcode) + cacheEntry.SetRcode(request.Req, cacheEntry.Rcode) - // Adjust TTL - setTTLInCachedResponse(val, ttl) - - if val.Rcode == dns.RcodeSuccess { - return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil + if cacheEntry.Rcode == dns.RcodeSuccess { + return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil } - return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil + return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil } logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver") - response, err = r.next.Resolve(ctx, request) + response, err = r.next.Resolve(ctx, request) if err == nil { - cacheTTL := r.adjustTTLs(response.Res.Answer) - r.putInCache(ctx, cacheKey, response, cacheTTL, true) + ttl, cacheEntry := r.createCacheEntry(logger, response.Res) + if ttl != 0 && len(cacheEntry) > 0 { + r.resultCache.Put(cacheKey, &cacheEntry, util.ToTTLDuration(ttl)) + + if r.redisClient != nil { + r.redisClient.PublishCache(cacheKey, ttl, cacheEntry) + } + } } } return response, err } -func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) { - val, ttl := r.resultCache.Get(key) - if val == nil { - return nil, 0 +func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) *dns.Msg { + raw, ttl := r.resultCache.Get(key) + if raw == nil { + return nil } res := new(dns.Msg) - err := res.Unpack(*val) + err := res.Unpack(*raw) if err != nil { logger.Error("can't unpack cached entry. Cache malformed?", err) - return nil, 0 + return nil } - return res, ttl -} + // Adjust TTL + util.AdjustAnswerTTL(res, ttl) -func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) { - minTTL := uint32(math.MaxInt32) - // find smallest TTL first - for _, rr := range resp.Answer { - minTTL = min(minTTL, rr.Header().Ttl) - } - - for _, rr := range resp.Answer { - rr.Header().Ttl = rr.Header().Ttl - minTTL + uint32(ttl.Seconds()) - } + return res } // isRequestCacheable returns true if the request should be cached @@ -244,72 +240,50 @@ func isRequestCacheable(request *model.Request) bool { return true } -// isResponseCacheable returns true if the response is not truncated and its CD flag isn't set. -func isResponseCacheable(msg *dns.Msg) bool { - // we don't cache truncated responses and responses with CD flag - return !msg.Truncated && !msg.CheckingDisabled +func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 { + // if response is empty or negative, return negative cache time from config + if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError { + return util.ToTTL(r.cfg.CacheTimeNegative) + } + + // if response is truncated or CD flag is set, return noCacheTTL since we don't cache these responses + if response.Truncated || response.CheckingDisabled { + return 0 + } + + // if response is not successful, return noCacheTTL since we don't cache these responses + if response.Rcode != dns.RcodeSuccess { + return 0 + } + + // adjust TTLs of all answers to match the configured min and max caching times + util.SetAnswerMinMaxTTL(response, r.cfg.MinCachingTime, r.cfg.MaxCachingTime) + + return util.GetAnswerMinTTL(response) } -func (r *CachingResolver) putInCache( - ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool, -) { - respCopy := response.Res.Copy() +func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg) (uint32, []byte) { + ttl := r.modifyResponseTTL(input) + if ttl == 0 { + logger.Debug("response is not cacheable") + + return 0, nil + } + + internalMsg := input.Copy() + internalMsg.Compress = true // don't cache any EDNS OPT records - util.RemoveEdns0Record(respCopy) + util.RemoveEdns0Record(internalMsg) - packed, err := respCopy.Pack() - util.LogOnError(ctx, "error on packing", err) + packed, err := internalMsg.Pack() + if err != nil { + logger.WithError(err).Warn("response packing failed") - if err == nil { - if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) { - // put value into cache - r.resultCache.Put(cacheKey, &packed, ttl) - } else if response.Res.Rcode == dns.RcodeNameError { - if r.cfg.CacheTimeNegative.IsAboveZero() { - // put negative cache if result code is NXDOMAIN - r.resultCache.Put(cacheKey, &packed, r.cfg.CacheTimeNegative.ToDuration()) - } - } + return 0, nil } - if publish && r.redisClient != nil { - res := *respCopy - r.redisClient.PublishCache(cacheKey, &res) - } -} - -// adjustTTLs calculates and returns the min TTL (considers also the min and max cache time) -// for all records from answer or a negative cache time for empty answer -// adjust the TTL in the answer header accordingly -func (r *CachingResolver) adjustTTLs(answer []dns.RR) (ttl time.Duration) { - minTTL := uint32(math.MaxInt32) - - if len(answer) == 0 { - return r.cfg.CacheTimeNegative.ToDuration() - } - - for _, a := range answer { - // if TTL < mitTTL -> adjust the value, set minTTL - if r.cfg.MinCachingTime.IsAboveZero() { - if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() { - atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32()) - } - } - - if r.cfg.MaxCachingTime.IsAboveZero() { - if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() { - atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32()) - } - } - - headerTTL := atomic.LoadUint32(&a.Header().Ttl) - if minTTL > headerTTL { - minTTL = headerTTL - } - } - - return time.Duration(minTTL) * time.Second + return ttl, packed } func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) { diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index 5081ab54..20caa196 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -782,13 +782,13 @@ var _ = Describe("CachingResolver", func() { request := newRequest("example2.com.", A) domain := util.ExtractDomain(request.Req.Question[0]) cacheKey := util.GenerateCacheKey(A, domain) - redisMockMsg := &redis.CacheMessage{ - Key: cacheKey, - Response: &Response{ - RType: ResponseTypeCACHED, - Reason: "MOCK_REDIS", - Res: mockAnswer, - }, + binMsg, err := mockAnswer.Pack() + Expect(err).Should(Succeed()) + + redisMockMsg := &redis.CacheEntry{ + TTL: 123, + Key: cacheKey, + Entry: binMsg, } redisClient.CacheChannel <- redisMockMsg diff --git a/util/dns.go b/util/dns.go new file mode 100644 index 00000000..fcf635ec --- /dev/null +++ b/util/dns.go @@ -0,0 +1,144 @@ +package util + +import ( + "math" + "sync/atomic" + "time" + + "github.com/miekg/dns" +) + +// TTLInput is the input type for TTL values and consists of the following underlying types: +// int, uint, uint32, int64 +type TTLInput interface { + ~int | ~uint | ~uint32 | ~int64 +} + +// ToTTL converts the input to a TTL of seconds as uint32. +// +// If the input is of underlying type time.Duration, the value is converted to seconds. +// +// If the input is negative, the TTL is set to 0. +// +// If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32. +func ToTTL[T TTLInput](input T) uint32 { + // fast return if the input is zero or below + if input <= 0 { + return 0 + } + + // fast return if the input is already of type uint32 + if ui32Type, ok := any(input).(uint32); ok { + return ui32Type + } + + // use int64 as the intermediate type for conversion + res := int64(input) + + // check if the input is of underlying type time.Duration + if durType, ok := any(input).(interface{ Seconds() float64 }); ok { + res = int64(durType.Seconds()) + } + + // check if the value is negative or greater than the maximum value of uint32 + if res < 0 { + // there is no negative TTL + return 0 + } else if res > math.MaxUint32 { + // since TTL is a 32-bit unsigned integer, the maximum value is math.MaxUint32 + return math.MaxUint32 + } + + // return the value as uint32 + return uint32(res) +} + +// ToTTLDuration converts the input to a time.Duration. +// +// The input is converted to a TTL of seconds as uint32 and then to a time.Duration. +func ToTTLDuration[T TTLInput](input T) time.Duration { + return time.Duration(ToTTL(input)) * time.Second +} + +// SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to +// the minimum TTL. +func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) { + if minTTL := ToTTL(min); minTTL != 0 { + for _, answer := range msg.Answer { + if atomic.LoadUint32(&answer.Header().Ttl) < minTTL { + atomic.StoreUint32(&answer.Header().Ttl, minTTL) + } + } + } +} + +// SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL +// to the maximum TTL. +func SetAnswerMaxTTL[T TTLInput](msg *dns.Msg, max T) { + if maxTTL := ToTTL(max); maxTTL != 0 { + for _, answer := range msg.Answer { + if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL { + atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + } + } + } +} + +// SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL +// to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL. +func SetAnswerMinMaxTTL[T, TT TTLInput](msg *dns.Msg, min T, max TT) { + minTTL := ToTTL(min) + maxTTL := ToTTL(max) + + switch { + case minTTL == 0 && maxTTL == 0: + // no TTL specified, fast return + return + case minTTL != 0 && maxTTL == 0: + // only minimum TTL specified + SetAnswerMinTTL(msg, min) + case minTTL == 0 && maxTTL != 0: + // only maximum TTL specified + SetAnswerMaxTTL(msg, max) + default: + // both minimum and maximum TTL specified + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + if headerTTL < minTTL { + atomic.StoreUint32(&answer.Header().Ttl, minTTL) + } else if headerTTL > maxTTL { + atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + } + } + } +} + +// GetMinAnswerTTL returns the lowest TTL of all answers in the message. +func GetAnswerMinTTL(msg *dns.Msg) uint32 { + var minTTL atomic.Uint32 + // initialize minTTL with the maximum value of uint32 + minTTL.Store(math.MaxUint32) + + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + if headerTTL < minTTL.Load() { + minTTL.Store(headerTTL) + } + } + + return minTTL.Load() +} + +// AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL +// and the answer's TTL plus the specified adjustment. +// +// If the adjustment is zero, the TTL is not changed. +func AdjustAnswerTTL[T TTLInput](msg *dns.Msg, adjustment T) { + adjustmentTTL := ToTTL(adjustment) + minTTL := GetAnswerMinTTL(msg) + + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL) + } +} diff --git a/util/dns_test.go b/util/dns_test.go new file mode 100644 index 00000000..ef687488 --- /dev/null +++ b/util/dns_test.go @@ -0,0 +1,37 @@ +package util + +import ( + "math" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("EDNS0 utils", func() { + DescribeTable("ToTTL", + func(input interface{}, expected int) { + res := uint32(0) + switch it := input.(type) { + case uint32: + res = ToTTL(it) + case int: + res = ToTTL(it) + case int64: + res = ToTTL(it) + case time.Duration: + res = ToTTL(it) + default: + Fail("unsupported type") + } + + Expect(ToTTL(res)).Should(Equal(uint32(expected))) + }, + Entry("should return 0 for negative input", -1, 0), + Entry("should return uint32 for uint32 input", uint32(1), 1), + Entry("should return uint32 for int input", 1, 1), + Entry("should return uint32 for int64 input", int64(1), 1), + Entry("should return seconds for time.Duration input", time.Second, 1), + Entry("should return math.MaxUint32 for too large input", int64(math.MaxUint32)+1, math.MaxUint32), + ) +})