mirror of https://github.com/0xERR0R/blocky.git
Merge 7b2fcc2953
into 63468a7168
This commit is contained in:
commit
9898dbe1cd
|
@ -103,14 +103,14 @@ func periodicCleanup[T any](ctx context.Context, c *ExpiringLRUCache[T]) {
|
|||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.cleanUp()
|
||||
c.cleanUp(ctx)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ExpiringLRUCache[T]) cleanUp() {
|
||||
func (e *ExpiringLRUCache[T]) cleanUp(ctx context.Context) {
|
||||
var expiredKeys []string
|
||||
|
||||
// check for expired items and collect expired keys
|
||||
|
@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
|
|||
var keysToDelete []string
|
||||
|
||||
for _, key := range expiredKeys {
|
||||
newVal, newTTL := e.preExpirationFn(context.Background(), key)
|
||||
newVal, newTTL := e.preExpirationFn(ctx, key)
|
||||
if newVal != nil {
|
||||
e.Put(key, newVal, newTTL)
|
||||
} else {
|
||||
|
|
|
@ -181,7 +181,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
// trigger cleanUp manually -> onExpiredFn will be executed, because element is expired
|
||||
cache.cleanUp()
|
||||
cache.cleanUp(ctx)
|
||||
|
||||
// wait for expiration
|
||||
val, ttl := cache.Get("key1")
|
||||
|
|
284
redis/redis.go
284
redis/redis.go
|
@ -5,13 +5,11 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/google/uuid"
|
||||
|
@ -29,10 +27,10 @@ const (
|
|||
messageTypeEnable = 1
|
||||
)
|
||||
|
||||
// sendBuffer message
|
||||
type bufferMessage struct {
|
||||
Key string
|
||||
Message *dns.Msg
|
||||
type CacheEntry struct {
|
||||
TTL uint32
|
||||
Key string
|
||||
Entry []byte
|
||||
}
|
||||
|
||||
// redis pubsub message
|
||||
|
@ -43,12 +41,6 @@ type redisMessage struct {
|
|||
Client []byte `json:"c"`
|
||||
}
|
||||
|
||||
// CacheChannel message
|
||||
type CacheMessage struct {
|
||||
Key string
|
||||
Response *model.Response
|
||||
}
|
||||
|
||||
type EnabledMessage struct {
|
||||
State bool `json:"s"`
|
||||
Duration time.Duration `json:"d,omitempty"`
|
||||
|
@ -61,8 +53,8 @@ type Client struct {
|
|||
client *redis.Client
|
||||
l *logrus.Entry
|
||||
id []byte
|
||||
sendBuffer chan *bufferMessage
|
||||
CacheChannel chan *CacheMessage
|
||||
sendBuffer chan *CacheEntry
|
||||
CacheChannel chan *CacheEntry
|
||||
EnabledChannel chan *EnabledMessage
|
||||
}
|
||||
|
||||
|
@ -100,55 +92,63 @@ func New(ctx context.Context, cfg *config.Redis) (*Client, error) {
|
|||
rdb := baseClient.WithContext(ctx)
|
||||
|
||||
_, err := rdb.Ping(ctx).Result()
|
||||
if err == nil {
|
||||
var id []byte
|
||||
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
if err == nil {
|
||||
// construct client
|
||||
res := &Client{
|
||||
config: cfg,
|
||||
client: rdb,
|
||||
l: log.PrefixedLog("redis"),
|
||||
id: id,
|
||||
sendBuffer: make(chan *bufferMessage, chanCap),
|
||||
CacheChannel: make(chan *CacheMessage, chanCap),
|
||||
EnabledChannel: make(chan *EnabledMessage, chanCap),
|
||||
}
|
||||
|
||||
// start channel handling go routine
|
||||
err = res.startup(ctx)
|
||||
|
||||
return res, err
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
var id []byte
|
||||
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// construct client
|
||||
res := &Client{
|
||||
config: cfg,
|
||||
client: rdb,
|
||||
l: log.PrefixedLog("redis"),
|
||||
id: id,
|
||||
sendBuffer: make(chan *CacheEntry, chanCap),
|
||||
CacheChannel: make(chan *CacheEntry, chanCap),
|
||||
EnabledChannel: make(chan *EnabledMessage, chanCap),
|
||||
}
|
||||
|
||||
// start channel handling go routine
|
||||
err = res.startup(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// PublishCache publish cache to redis async
|
||||
func (c *Client) PublishCache(key string, message *dns.Msg) {
|
||||
if len(key) > 0 && message != nil {
|
||||
c.sendBuffer <- &bufferMessage{
|
||||
Key: key,
|
||||
Message: message,
|
||||
// PublishCache publish cache entry to redis if key and message are not empty and ttl > 0
|
||||
func (c *Client) PublishCache(key string, ttl uint32, message []byte) {
|
||||
if len(key) > 0 && len(message) > 0 && ttl > 0 {
|
||||
c.sendBuffer <- &CacheEntry{
|
||||
TTL: ttl,
|
||||
Key: key,
|
||||
Entry: message,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) {
|
||||
binState, sErr := json.Marshal(state)
|
||||
if sErr == nil {
|
||||
binMsg, mErr := json.Marshal(redisMessage{
|
||||
Type: messageTypeEnable,
|
||||
Message: binState,
|
||||
Client: c.id,
|
||||
})
|
||||
|
||||
if mErr == nil {
|
||||
c.client.Publish(ctx, SyncChannelName, binMsg)
|
||||
}
|
||||
binState, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
binMsg, err := json.Marshal(redisMessage{
|
||||
Type: messageTypeEnable,
|
||||
Message: binState,
|
||||
Client: c.id,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.client.Publish(ctx, SyncChannelName, binMsg)
|
||||
}
|
||||
|
||||
// GetRedisCache reads the redis cache and publish it to the channel
|
||||
|
@ -183,56 +183,53 @@ func (c *Client) startup(ctx context.Context) error {
|
|||
ps := c.client.Subscribe(ctx, SyncChannelName)
|
||||
|
||||
_, err := ps.Receive(ctx)
|
||||
if err == nil {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
// received message from subscription
|
||||
case msg := <-ps.Channel():
|
||||
c.l.Debug("Received message: ", msg)
|
||||
|
||||
if msg != nil && len(msg.Payload) > 0 {
|
||||
// message is not empty
|
||||
c.processReceivedMessage(ctx, msg)
|
||||
}
|
||||
// publish message from buffer
|
||||
case s := <-c.sendBuffer:
|
||||
c.publishMessageFromBuffer(ctx, s)
|
||||
// context is done
|
||||
case <-ctx.Done():
|
||||
c.client.Close()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
go func() {
|
||||
defer ps.Close()
|
||||
defer c.client.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
// received message from subscription
|
||||
case msg := <-ps.Channel():
|
||||
c.l.Debug("Received message: ", msg)
|
||||
|
||||
if msg != nil && len(msg.Payload) > 0 {
|
||||
// message is not empty
|
||||
c.processReceivedMessage(ctx, msg)
|
||||
}
|
||||
// publish message from buffer
|
||||
case s := <-c.sendBuffer:
|
||||
c.publishMessageFromBuffer(ctx, s)
|
||||
// context is done
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) {
|
||||
origRes := s.Message
|
||||
origRes.Compress = true
|
||||
binRes, pErr := origRes.Pack()
|
||||
|
||||
if pErr == nil {
|
||||
binMsg, mErr := json.Marshal(redisMessage{
|
||||
Key: s.Key,
|
||||
Type: messageTypeCache,
|
||||
Message: binRes,
|
||||
Client: c.id,
|
||||
})
|
||||
|
||||
if mErr == nil {
|
||||
c.client.Publish(ctx, SyncChannelName, binMsg)
|
||||
}
|
||||
|
||||
c.client.Set(ctx,
|
||||
prefixKey(s.Key),
|
||||
binRes,
|
||||
c.getTTL(origRes))
|
||||
// publishMessageFromBuffer publishes a message from the buffer to the redis channel and stores it in the cache
|
||||
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *CacheEntry) {
|
||||
psMsg, err := json.Marshal(redisMessage{
|
||||
Key: s.Key,
|
||||
Type: messageTypeCache,
|
||||
Message: s.Entry,
|
||||
Client: c.id,
|
||||
})
|
||||
if err == nil {
|
||||
c.client.Publish(ctx, SyncChannelName, psMsg)
|
||||
}
|
||||
|
||||
c.client.Set(ctx,
|
||||
prefixKey(s.Key),
|
||||
s.Entry,
|
||||
util.ToTTLDuration(s.TTL))
|
||||
}
|
||||
|
||||
func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) {
|
||||
|
@ -248,26 +245,24 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message)
|
|||
if !bytes.Equal(rm.Client, c.id) {
|
||||
switch rm.Type {
|
||||
case messageTypeCache:
|
||||
var cm *CacheMessage
|
||||
|
||||
cm, err := convertMessage(&rm, 0)
|
||||
cm, err := convertMessage(0, rm.Key, rm.Message)
|
||||
if err != nil {
|
||||
c.l.Error("Processing CacheMessage error: ", err)
|
||||
c.l.Error(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
util.CtxSend(ctx, c.CacheChannel, cm)
|
||||
util.CtxSend(ctx, c.CacheChannel, &cm)
|
||||
case messageTypeEnable:
|
||||
var msg EnabledMessage
|
||||
msg := new(EnabledMessage)
|
||||
|
||||
if err := json.Unmarshal(rm.Message, &msg); err != nil {
|
||||
if err := json.Unmarshal(rm.Message, msg); err != nil {
|
||||
c.l.Error("Processing EnabledMessage error: ", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
util.CtxSend(ctx, c.EnabledChannel, &msg)
|
||||
util.CtxSend(ctx, c.EnabledChannel, msg)
|
||||
default:
|
||||
c.l.Warn("Unknown message type: ", rm.Type)
|
||||
}
|
||||
|
@ -275,69 +270,48 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message)
|
|||
}
|
||||
|
||||
// getResponse returns model.Response for a key
|
||||
func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) {
|
||||
func (c *Client) getResponse(ctx context.Context, key string) (*CacheEntry, error) {
|
||||
resp, err := c.client.Get(ctx, key).Result()
|
||||
if err == nil {
|
||||
var ttl time.Duration
|
||||
ttl, err = c.client.TTL(ctx, key).Result()
|
||||
|
||||
if err == nil {
|
||||
var result *CacheMessage
|
||||
|
||||
result, err = convertMessage(&redisMessage{
|
||||
Key: cleanKey(key),
|
||||
Message: []byte(resp),
|
||||
}, ttl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("conversion error: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
ttl, err := c.client.TTL(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := convertMessage(ttl, cleanKey(key), []byte(resp))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// convertMessage converts redisMessage to CacheMessage
|
||||
func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, error) {
|
||||
msg := dns.Msg{}
|
||||
|
||||
err := msg.Unpack(message.Message)
|
||||
if err == nil {
|
||||
if ttl > 0 {
|
||||
for _, a := range msg.Answer {
|
||||
a.Header().Ttl = uint32(ttl.Seconds())
|
||||
}
|
||||
}
|
||||
|
||||
res := &CacheMessage{
|
||||
Key: message.Key,
|
||||
Response: &model.Response{
|
||||
RType: model.ResponseTypeCACHED,
|
||||
Reason: cacheReason,
|
||||
Res: &msg,
|
||||
},
|
||||
}
|
||||
|
||||
return res, nil
|
||||
func convertMessage[T util.TTLInput](ttl T, key string, message []byte) (CacheEntry, error) {
|
||||
res := CacheEntry{
|
||||
TTL: util.ToTTL(ttl),
|
||||
Key: key,
|
||||
Entry: message,
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
packErr := fmt.Errorf("invalid message for key %s", key)
|
||||
if len(message) == 0 {
|
||||
return res, packErr
|
||||
}
|
||||
|
||||
// getTTL of dns message or return defaultCacheTime if 0
|
||||
func (c *Client) getTTL(dns *dns.Msg) time.Duration {
|
||||
ttl := uint32(math.MaxInt32)
|
||||
for _, a := range dns.Answer {
|
||||
ttl = min(ttl, a.Header().Ttl)
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(message); err != nil {
|
||||
return res, packErr
|
||||
}
|
||||
|
||||
if ttl == 0 {
|
||||
return defaultCacheTime
|
||||
res.TTL = util.GetAnswerMinTTL(msg)
|
||||
}
|
||||
|
||||
return time.Duration(ttl) * time.Second
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// prefixKey with CacheStorePrefix
|
||||
|
|
|
@ -96,10 +96,12 @@ var _ = Describe("Redis client", func() {
|
|||
|
||||
By("publish new message with TTL > 0", func() {
|
||||
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", res)
|
||||
binRes, err := res.Pack()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", 123, binRes)
|
||||
})
|
||||
|
||||
By("Database has one entry with correct TTL", func() {
|
||||
|
@ -111,34 +113,6 @@ var _ = Describe("Redis client", func() {
|
|||
Expect(ttl.Seconds()).Should(BeNumerically("~", 123))
|
||||
})
|
||||
})
|
||||
|
||||
It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
By("Database is empty", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
}).Should(BeEmpty())
|
||||
})
|
||||
|
||||
By("publish new message with TTL = 0", func() {
|
||||
res, err := util.NewMsgWithAnswer("example.com.", 0, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", res)
|
||||
})
|
||||
|
||||
By("Database has one entry with default TTL", func() {
|
||||
Eventually(func() bool {
|
||||
return redisServer.DB(redisConfig.Database).Exists(exampleComKey)
|
||||
}).Should(BeTrue())
|
||||
|
||||
ttl := redisServer.DB(redisConfig.Database).TTL(exampleComKey)
|
||||
Expect(ttl.Seconds()).Should(BeNumerically("~", defaultCacheTime.Seconds()))
|
||||
})
|
||||
})
|
||||
})
|
||||
When("Redis client publishes 'enabled' message", func() {
|
||||
It("should propagate the message over redis", func(ctx context.Context) {
|
||||
|
@ -222,7 +196,7 @@ var _ = Describe("Redis client", func() {
|
|||
rec := redisServer.Publish(SyncChannelName, string(binMsg))
|
||||
Expect(rec).Should(Equal(1))
|
||||
|
||||
Eventually(func() chan *CacheMessage {
|
||||
Eventually(func() chan *CacheEntry {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenE + 1))
|
||||
}, SpecTimeout(time.Second*6))
|
||||
|
@ -255,7 +229,7 @@ var _ = Describe("Redis client", func() {
|
|||
return redisClient.EnabledChannel
|
||||
}).Should(HaveLen(lenE))
|
||||
|
||||
Eventually(func() chan *CacheMessage {
|
||||
Eventually(func() chan *CacheEntry {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenC))
|
||||
}, SpecTimeout(time.Second*6))
|
||||
|
@ -288,7 +262,7 @@ var _ = Describe("Redis client", func() {
|
|||
return redisClient.EnabledChannel
|
||||
}).Should(HaveLen(lenE))
|
||||
|
||||
Eventually(func() chan *CacheMessage {
|
||||
Eventually(func() chan *CacheEntry {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenC))
|
||||
}, SpecTimeout(time.Second*6))
|
||||
|
@ -312,13 +286,13 @@ var _ = Describe("Redis client", func() {
|
|||
})
|
||||
|
||||
By("Put valid data in Redis by publishing the cache entry", func() {
|
||||
var res *dns.Msg
|
||||
|
||||
res, err = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", res)
|
||||
binRes, err := res.Pack()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", 123, binRes)
|
||||
})
|
||||
|
||||
By("Database has one entry now", func() {
|
||||
|
|
|
@ -2,9 +2,6 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/cache/expirationcache"
|
||||
|
@ -18,7 +15,9 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const defaultCachingCleanUpInterval = 5 * time.Second
|
||||
const (
|
||||
defaultCachingCleanUpInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// CachingResolver caches answers from dns queries with their TTL time,
|
||||
// to avoid external resolver calls for recurrent queries
|
||||
|
@ -107,28 +106,29 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
|
|||
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
|
||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||
ctx, logger := r.log(ctx)
|
||||
logger = logger.WithField("domain", util.Obfuscate(domainName))
|
||||
|
||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||
logger.Debugf("prefetching %s", qType)
|
||||
|
||||
req := newRequest(dns.Fqdn(domainName), qType)
|
||||
|
||||
response, err := r.next.Resolve(ctx, req)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("cache prefetch failed")
|
||||
|
||||
if err == nil {
|
||||
if response.Res.Rcode == dns.RcodeSuccess {
|
||||
packed, err := response.Res.Pack()
|
||||
if err != nil {
|
||||
logger.Error("unable to pack response", err)
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
return &packed, r.adjustTTLs(response.Res.Answer)
|
||||
}
|
||||
} else {
|
||||
util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
ttl, res := r.createCacheEntry(logger, response.Res)
|
||||
if ttl == 0 || len(res) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
if r.redisClient != nil {
|
||||
r.redisClient.PublishCache(cacheKey, ttl, res)
|
||||
}
|
||||
|
||||
return &res, util.ToTTLDuration(ttl)
|
||||
}
|
||||
|
||||
func (r *CachingResolver) redisSubscriber(ctx context.Context) {
|
||||
|
@ -138,9 +138,11 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) {
|
|||
select {
|
||||
case rc := <-r.redisClient.CacheChannel:
|
||||
if rc != nil {
|
||||
logger.Debug("Received key from redis: ", rc.Key)
|
||||
ttl := r.adjustTTLs(rc.Response.Res.Answer)
|
||||
r.putInCache(ctx, rc.Key, rc.Response, ttl, false)
|
||||
_, domain := util.ExtractCacheKey(rc.Key)
|
||||
|
||||
logger.WithField("domain", util.Obfuscate(domain)).Debug("received from redis")
|
||||
|
||||
r.resultCache.Put(rc.Key, &rc.Entry, util.ToTTLDuration(rc.TTL))
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
|
@ -172,63 +174,57 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
|
|||
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
|
||||
logger := logger.WithField("domain", util.Obfuscate(domain))
|
||||
|
||||
val, ttl := r.getFromCache(logger, cacheKey)
|
||||
cacheEntry := r.getFromCache(logger, cacheKey)
|
||||
|
||||
if val != nil {
|
||||
if cacheEntry != nil {
|
||||
logger.Debug("domain is cached")
|
||||
|
||||
val.SetRcode(request.Req, val.Rcode)
|
||||
cacheEntry.SetRcode(request.Req, cacheEntry.Rcode)
|
||||
|
||||
// Adjust TTL
|
||||
setTTLInCachedResponse(val, ttl)
|
||||
|
||||
if val.Rcode == dns.RcodeSuccess {
|
||||
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
|
||||
if cacheEntry.Rcode == dns.RcodeSuccess {
|
||||
return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
|
||||
}
|
||||
|
||||
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
|
||||
return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
|
||||
}
|
||||
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
|
||||
response, err = r.next.Resolve(ctx, request)
|
||||
|
||||
response, err = r.next.Resolve(ctx, request)
|
||||
if err == nil {
|
||||
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
||||
r.putInCache(ctx, cacheKey, response, cacheTTL, true)
|
||||
ttl, cacheEntry := r.createCacheEntry(logger, response.Res)
|
||||
if ttl != 0 && len(cacheEntry) > 0 {
|
||||
r.resultCache.Put(cacheKey, &cacheEntry, util.ToTTLDuration(ttl))
|
||||
|
||||
if r.redisClient != nil {
|
||||
r.redisClient.PublishCache(cacheKey, ttl, cacheEntry)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) {
|
||||
val, ttl := r.resultCache.Get(key)
|
||||
if val == nil {
|
||||
return nil, 0
|
||||
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) *dns.Msg {
|
||||
raw, ttl := r.resultCache.Get(key)
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
res := new(dns.Msg)
|
||||
|
||||
err := res.Unpack(*val)
|
||||
err := res.Unpack(*raw)
|
||||
if err != nil {
|
||||
logger.Error("can't unpack cached entry. Cache malformed?", err)
|
||||
|
||||
return nil, 0
|
||||
return nil
|
||||
}
|
||||
|
||||
return res, ttl
|
||||
}
|
||||
// Adjust TTL
|
||||
util.AdjustAnswerTTL(res, ttl)
|
||||
|
||||
func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
|
||||
minTTL := uint32(math.MaxInt32)
|
||||
// find smallest TTL first
|
||||
for _, rr := range resp.Answer {
|
||||
minTTL = min(minTTL, rr.Header().Ttl)
|
||||
}
|
||||
|
||||
for _, rr := range resp.Answer {
|
||||
rr.Header().Ttl = rr.Header().Ttl - minTTL + uint32(ttl.Seconds())
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// isRequestCacheable returns true if the request should be cached
|
||||
|
@ -244,72 +240,50 @@ func isRequestCacheable(request *model.Request) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// isResponseCacheable returns true if the response is not truncated and its CD flag isn't set.
|
||||
func isResponseCacheable(msg *dns.Msg) bool {
|
||||
// we don't cache truncated responses and responses with CD flag
|
||||
return !msg.Truncated && !msg.CheckingDisabled
|
||||
func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 {
|
||||
// if response is empty or negative, return negative cache time from config
|
||||
if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError {
|
||||
return util.ToTTL(r.cfg.CacheTimeNegative)
|
||||
}
|
||||
|
||||
// if response is truncated or CD flag is set, return noCacheTTL since we don't cache these responses
|
||||
if response.Truncated || response.CheckingDisabled {
|
||||
return 0
|
||||
}
|
||||
|
||||
// if response is not successful, return noCacheTTL since we don't cache these responses
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
return 0
|
||||
}
|
||||
|
||||
// adjust TTLs of all answers to match the configured min and max caching times
|
||||
util.SetAnswerMinMaxTTL(response, r.cfg.MinCachingTime, r.cfg.MaxCachingTime)
|
||||
|
||||
return util.GetAnswerMinTTL(response)
|
||||
}
|
||||
|
||||
func (r *CachingResolver) putInCache(
|
||||
ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool,
|
||||
) {
|
||||
respCopy := response.Res.Copy()
|
||||
func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg) (uint32, []byte) {
|
||||
ttl := r.modifyResponseTTL(input)
|
||||
if ttl == 0 {
|
||||
logger.Debug("response is not cacheable")
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
internalMsg := input.Copy()
|
||||
internalMsg.Compress = true
|
||||
|
||||
// don't cache any EDNS OPT records
|
||||
util.RemoveEdns0Record(respCopy)
|
||||
util.RemoveEdns0Record(internalMsg)
|
||||
|
||||
packed, err := respCopy.Pack()
|
||||
util.LogOnError(ctx, "error on packing", err)
|
||||
packed, err := internalMsg.Pack()
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("response packing failed")
|
||||
|
||||
if err == nil {
|
||||
if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) {
|
||||
// put value into cache
|
||||
r.resultCache.Put(cacheKey, &packed, ttl)
|
||||
} else if response.Res.Rcode == dns.RcodeNameError {
|
||||
if r.cfg.CacheTimeNegative.IsAboveZero() {
|
||||
// put negative cache if result code is NXDOMAIN
|
||||
r.resultCache.Put(cacheKey, &packed, r.cfg.CacheTimeNegative.ToDuration())
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if publish && r.redisClient != nil {
|
||||
res := *respCopy
|
||||
r.redisClient.PublishCache(cacheKey, &res)
|
||||
}
|
||||
}
|
||||
|
||||
// adjustTTLs calculates and returns the min TTL (considers also the min and max cache time)
|
||||
// for all records from answer or a negative cache time for empty answer
|
||||
// adjust the TTL in the answer header accordingly
|
||||
func (r *CachingResolver) adjustTTLs(answer []dns.RR) (ttl time.Duration) {
|
||||
minTTL := uint32(math.MaxInt32)
|
||||
|
||||
if len(answer) == 0 {
|
||||
return r.cfg.CacheTimeNegative.ToDuration()
|
||||
}
|
||||
|
||||
for _, a := range answer {
|
||||
// if TTL < mitTTL -> adjust the value, set minTTL
|
||||
if r.cfg.MinCachingTime.IsAboveZero() {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() {
|
||||
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32())
|
||||
}
|
||||
}
|
||||
|
||||
if r.cfg.MaxCachingTime.IsAboveZero() {
|
||||
if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() {
|
||||
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32())
|
||||
}
|
||||
}
|
||||
|
||||
headerTTL := atomic.LoadUint32(&a.Header().Ttl)
|
||||
if minTTL > headerTTL {
|
||||
minTTL = headerTTL
|
||||
}
|
||||
}
|
||||
|
||||
return time.Duration(minTTL) * time.Second
|
||||
return ttl, packed
|
||||
}
|
||||
|
||||
func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) {
|
||||
|
|
|
@ -782,13 +782,13 @@ var _ = Describe("CachingResolver", func() {
|
|||
request := newRequest("example2.com.", A)
|
||||
domain := util.ExtractDomain(request.Req.Question[0])
|
||||
cacheKey := util.GenerateCacheKey(A, domain)
|
||||
redisMockMsg := &redis.CacheMessage{
|
||||
Key: cacheKey,
|
||||
Response: &Response{
|
||||
RType: ResponseTypeCACHED,
|
||||
Reason: "MOCK_REDIS",
|
||||
Res: mockAnswer,
|
||||
},
|
||||
binMsg, err := mockAnswer.Pack()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisMockMsg := &redis.CacheEntry{
|
||||
TTL: 123,
|
||||
Key: cacheKey,
|
||||
Entry: binMsg,
|
||||
}
|
||||
redisClient.CacheChannel <- redisMockMsg
|
||||
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// TTLInput is the input type for TTL values and consists of the following underlying types:
|
||||
// int, uint, uint32, int64
|
||||
type TTLInput interface {
|
||||
~int | ~uint | ~uint32 | ~int64
|
||||
}
|
||||
|
||||
// ToTTL converts the input to a TTL of seconds as uint32.
|
||||
//
|
||||
// If the input is of underlying type time.Duration, the value is converted to seconds.
|
||||
//
|
||||
// If the input is negative, the TTL is set to 0.
|
||||
//
|
||||
// If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32.
|
||||
func ToTTL[T TTLInput](input T) uint32 {
|
||||
// fast return if the input is zero or below
|
||||
if input <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// fast return if the input is already of type uint32
|
||||
if ui32Type, ok := any(input).(uint32); ok {
|
||||
return ui32Type
|
||||
}
|
||||
|
||||
// use int64 as the intermediate type for conversion
|
||||
res := int64(input)
|
||||
|
||||
// check if the input is of underlying type time.Duration
|
||||
if durType, ok := any(input).(interface{ Seconds() float64 }); ok {
|
||||
res = int64(durType.Seconds())
|
||||
}
|
||||
|
||||
// check if the value is negative or greater than the maximum value of uint32
|
||||
if res < 0 {
|
||||
// there is no negative TTL
|
||||
return 0
|
||||
} else if res > math.MaxUint32 {
|
||||
// since TTL is a 32-bit unsigned integer, the maximum value is math.MaxUint32
|
||||
return math.MaxUint32
|
||||
}
|
||||
|
||||
// return the value as uint32
|
||||
return uint32(res)
|
||||
}
|
||||
|
||||
// ToTTLDuration converts the input to a time.Duration.
|
||||
//
|
||||
// The input is converted to a TTL of seconds as uint32 and then to a time.Duration.
|
||||
func ToTTLDuration[T TTLInput](input T) time.Duration {
|
||||
return time.Duration(ToTTL(input)) * time.Second
|
||||
}
|
||||
|
||||
// SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to
|
||||
// the minimum TTL.
|
||||
func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) {
|
||||
if minTTL := ToTTL(min); minTTL != 0 {
|
||||
for _, answer := range msg.Answer {
|
||||
if atomic.LoadUint32(&answer.Header().Ttl) < minTTL {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, minTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL
|
||||
// to the maximum TTL.
|
||||
func SetAnswerMaxTTL[T TTLInput](msg *dns.Msg, max T) {
|
||||
if maxTTL := ToTTL(max); maxTTL != 0 {
|
||||
for _, answer := range msg.Answer {
|
||||
if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, maxTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL
|
||||
// to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL.
|
||||
func SetAnswerMinMaxTTL[T, TT TTLInput](msg *dns.Msg, min T, max TT) {
|
||||
minTTL := ToTTL(min)
|
||||
maxTTL := ToTTL(max)
|
||||
|
||||
switch {
|
||||
case minTTL == 0 && maxTTL == 0:
|
||||
// no TTL specified, fast return
|
||||
return
|
||||
case minTTL != 0 && maxTTL == 0:
|
||||
// only minimum TTL specified
|
||||
SetAnswerMinTTL(msg, min)
|
||||
case minTTL == 0 && maxTTL != 0:
|
||||
// only maximum TTL specified
|
||||
SetAnswerMaxTTL(msg, max)
|
||||
default:
|
||||
// both minimum and maximum TTL specified
|
||||
for _, answer := range msg.Answer {
|
||||
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
|
||||
if headerTTL < minTTL {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, minTTL)
|
||||
} else if headerTTL > maxTTL {
|
||||
atomic.StoreUint32(&answer.Header().Ttl, maxTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetMinAnswerTTL returns the lowest TTL of all answers in the message.
|
||||
func GetAnswerMinTTL(msg *dns.Msg) uint32 {
|
||||
var minTTL atomic.Uint32
|
||||
// initialize minTTL with the maximum value of uint32
|
||||
minTTL.Store(math.MaxUint32)
|
||||
|
||||
for _, answer := range msg.Answer {
|
||||
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
|
||||
if headerTTL < minTTL.Load() {
|
||||
minTTL.Store(headerTTL)
|
||||
}
|
||||
}
|
||||
|
||||
return minTTL.Load()
|
||||
}
|
||||
|
||||
// AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL
|
||||
// and the answer's TTL plus the specified adjustment.
|
||||
//
|
||||
// If the adjustment is zero, the TTL is not changed.
|
||||
func AdjustAnswerTTL[T TTLInput](msg *dns.Msg, adjustment T) {
|
||||
adjustmentTTL := ToTTL(adjustment)
|
||||
minTTL := GetAnswerMinTTL(msg)
|
||||
|
||||
for _, answer := range msg.Answer {
|
||||
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
|
||||
atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("EDNS0 utils", func() {
|
||||
DescribeTable("ToTTL",
|
||||
func(input interface{}, expected int) {
|
||||
res := uint32(0)
|
||||
switch it := input.(type) {
|
||||
case uint32:
|
||||
res = ToTTL(it)
|
||||
case int:
|
||||
res = ToTTL(it)
|
||||
case int64:
|
||||
res = ToTTL(it)
|
||||
case time.Duration:
|
||||
res = ToTTL(it)
|
||||
default:
|
||||
Fail("unsupported type")
|
||||
}
|
||||
|
||||
Expect(ToTTL(res)).Should(Equal(uint32(expected)))
|
||||
},
|
||||
Entry("should return 0 for negative input", -1, 0),
|
||||
Entry("should return uint32 for uint32 input", uint32(1), 1),
|
||||
Entry("should return uint32 for int input", 1, 1),
|
||||
Entry("should return uint32 for int64 input", int64(1), 1),
|
||||
Entry("should return seconds for time.Duration input", time.Second, 1),
|
||||
Entry("should return math.MaxUint32 for too large input", int64(math.MaxUint32)+1, math.MaxUint32),
|
||||
)
|
||||
})
|
Loading…
Reference in New Issue