diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 4e4b8822..e7925a24 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -121,23 +121,16 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) return nil, 0 } - cacheCopy, ttl := r.createCacheEntry(logger, response.Res) - if cacheCopy == nil || ttl == noCacheTTL { - return nil, 0 - } - - packed, err := cacheCopy.Pack() - if err != nil { - logger.WithError(err).WithError(err).Warn("response packing failed") - + ttl, res := r.createCacheEntry(logger, response.Res) + if ttl == noCacheTTL || len(res) == 0 { return nil, 0 } if r.redisClient != nil { - r.redisClient.PublishCache(cacheKey, cacheCopy) + r.redisClient.PublishCache(cacheKey, ttl, res) } - return &packed, util.ToTTLDuration(ttl) + return &res, util.ToTTLDuration(ttl) } func (r *CachingResolver) redisSubscriber(ctx context.Context) { @@ -153,7 +146,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { dlogger.Debug("received from redis") - r.putInCache(dlogger, rc.Key, rc.Response) + // TODO: Add to cache } case <-ctx.Done(): @@ -203,11 +196,12 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( response, err = r.next.Resolve(ctx, request) if err == nil { - ttl := r.modifyResponseTTL(response.Res) - if ttl > noCacheTTL { - cacheCopy := r.putInCache(logger, cacheKey, response) - if cacheCopy != nil && r.redisClient != nil { - r.redisClient.PublishCache(cacheKey, cacheCopy) + ttl, cacheEntry := r.createCacheEntry(logger, response.Res) + if ttl != noCacheTTL && len(cacheEntry) > 0 { + r.resultCache.Put(cacheKey, &cacheEntry, util.ToTTLDuration(ttl)) + + if r.redisClient != nil { + r.redisClient.PublishCache(cacheKey, ttl, cacheEntry) } } } @@ -250,24 +244,6 @@ func isRequestCacheable(request *model.Request) bool { return true } -func (r *CachingResolver) putInCache(logger *logrus.Entry, cacheKey string, response *model.Response) *dns.Msg { - cacheCopy, ttl := r.createCacheEntry(logger, response.Res) - if cacheCopy == nil || ttl == noCacheTTL { - return nil - } - - packed, err := cacheCopy.Pack() - if err != nil { - logger.WithError(err).Warn("response packing failed") - - return nil - } - - r.resultCache.Put(cacheKey, &packed, util.ToTTLDuration(ttl)) - - return cacheCopy -} - func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 { // if response is empty or negative, return negative cache time from config if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError { @@ -290,21 +266,28 @@ func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 { return util.GetAnswerMinTTL(response) } -func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg, -) (*dns.Msg, uint32) { - response := input.Copy() - - ttl := r.modifyResponseTTL(response) +func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg) (uint32, []byte) { + ttl := r.modifyResponseTTL(input) if ttl == noCacheTTL { logger.Debug("response is not cacheable") - return nil, 0 + return 0, nil } - // don't cache any EDNS OPT records - util.RemoveEdns0Record(response) + internalMsg := input.Copy() + internalMsg.Compress = true - return response, ttl + // don't cache any EDNS OPT records + util.RemoveEdns0Record(internalMsg) + + packed, err := internalMsg.Pack() + if err != nil { + logger.WithError(err).Warn("response packing failed") + + return 0, nil + } + + return ttl, packed } func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) {