This commit is contained in:
Kwitsch 2024-05-08 04:06:32 +03:00 committed by GitHub
commit 9898dbe1cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 417 additions and 314 deletions

View File

@ -103,14 +103,14 @@ func periodicCleanup[T any](ctx context.Context, c *ExpiringLRUCache[T]) {
for {
select {
case <-ticker.C:
c.cleanUp()
c.cleanUp(ctx)
case <-ctx.Done():
return
}
}
}
func (e *ExpiringLRUCache[T]) cleanUp() {
func (e *ExpiringLRUCache[T]) cleanUp(ctx context.Context) {
var expiredKeys []string
// check for expired items and collect expired keys
@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
var keysToDelete []string
for _, key := range expiredKeys {
newVal, newTTL := e.preExpirationFn(context.Background(), key)
newVal, newTTL := e.preExpirationFn(ctx, key)
if newVal != nil {
e.Put(key, newVal, newTTL)
} else {

View File

@ -181,7 +181,7 @@ var _ = Describe("Expiration cache", func() {
time.Sleep(2 * time.Millisecond)
// trigger cleanUp manually -> onExpiredFn will be executed, because element is expired
cache.cleanUp()
cache.cleanUp(ctx)
// wait for expiration
val, ttl := cache.Get("key1")

View File

@ -5,13 +5,11 @@ import (
"context"
"encoding/json"
"fmt"
"math"
"strings"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
@ -29,10 +27,10 @@ const (
messageTypeEnable = 1
)
// sendBuffer message
type bufferMessage struct {
Key string
Message *dns.Msg
type CacheEntry struct {
TTL uint32
Key string
Entry []byte
}
// redis pubsub message
@ -43,12 +41,6 @@ type redisMessage struct {
Client []byte `json:"c"`
}
// CacheChannel message
type CacheMessage struct {
Key string
Response *model.Response
}
type EnabledMessage struct {
State bool `json:"s"`
Duration time.Duration `json:"d,omitempty"`
@ -61,8 +53,8 @@ type Client struct {
client *redis.Client
l *logrus.Entry
id []byte
sendBuffer chan *bufferMessage
CacheChannel chan *CacheMessage
sendBuffer chan *CacheEntry
CacheChannel chan *CacheEntry
EnabledChannel chan *EnabledMessage
}
@ -100,55 +92,63 @@ func New(ctx context.Context, cfg *config.Redis) (*Client, error) {
rdb := baseClient.WithContext(ctx)
_, err := rdb.Ping(ctx).Result()
if err == nil {
var id []byte
id, err = uuid.New().MarshalBinary()
if err == nil {
// construct client
res := &Client{
config: cfg,
client: rdb,
l: log.PrefixedLog("redis"),
id: id,
sendBuffer: make(chan *bufferMessage, chanCap),
CacheChannel: make(chan *CacheMessage, chanCap),
EnabledChannel: make(chan *EnabledMessage, chanCap),
}
// start channel handling go routine
err = res.startup(ctx)
return res, err
}
if err != nil {
return nil, err
}
return nil, err
var id []byte
id, err = uuid.New().MarshalBinary()
if err != nil {
return nil, err
}
// construct client
res := &Client{
config: cfg,
client: rdb,
l: log.PrefixedLog("redis"),
id: id,
sendBuffer: make(chan *CacheEntry, chanCap),
CacheChannel: make(chan *CacheEntry, chanCap),
EnabledChannel: make(chan *EnabledMessage, chanCap),
}
// start channel handling go routine
err = res.startup(ctx)
if err != nil {
return nil, err
}
return res, nil
}
// PublishCache publish cache to redis async
func (c *Client) PublishCache(key string, message *dns.Msg) {
if len(key) > 0 && message != nil {
c.sendBuffer <- &bufferMessage{
Key: key,
Message: message,
// PublishCache publish cache entry to redis if key and message are not empty and ttl > 0
func (c *Client) PublishCache(key string, ttl uint32, message []byte) {
if len(key) > 0 && len(message) > 0 && ttl > 0 {
c.sendBuffer <- &CacheEntry{
TTL: ttl,
Key: key,
Entry: message,
}
}
}
func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) {
binState, sErr := json.Marshal(state)
if sErr == nil {
binMsg, mErr := json.Marshal(redisMessage{
Type: messageTypeEnable,
Message: binState,
Client: c.id,
})
if mErr == nil {
c.client.Publish(ctx, SyncChannelName, binMsg)
}
binState, err := json.Marshal(state)
if err != nil {
return
}
binMsg, err := json.Marshal(redisMessage{
Type: messageTypeEnable,
Message: binState,
Client: c.id,
})
if err != nil {
return
}
c.client.Publish(ctx, SyncChannelName, binMsg)
}
// GetRedisCache reads the redis cache and publish it to the channel
@ -183,56 +183,53 @@ func (c *Client) startup(ctx context.Context) error {
ps := c.client.Subscribe(ctx, SyncChannelName)
_, err := ps.Receive(ctx)
if err == nil {
go func() {
for {
select {
// received message from subscription
case msg := <-ps.Channel():
c.l.Debug("Received message: ", msg)
if msg != nil && len(msg.Payload) > 0 {
// message is not empty
c.processReceivedMessage(ctx, msg)
}
// publish message from buffer
case s := <-c.sendBuffer:
c.publishMessageFromBuffer(ctx, s)
// context is done
case <-ctx.Done():
c.client.Close()
return
}
}
}()
if err != nil {
return err
}
return err
go func() {
defer ps.Close()
defer c.client.Close()
for {
select {
// received message from subscription
case msg := <-ps.Channel():
c.l.Debug("Received message: ", msg)
if msg != nil && len(msg.Payload) > 0 {
// message is not empty
c.processReceivedMessage(ctx, msg)
}
// publish message from buffer
case s := <-c.sendBuffer:
c.publishMessageFromBuffer(ctx, s)
// context is done
case <-ctx.Done():
return
}
}
}()
return nil
}
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) {
origRes := s.Message
origRes.Compress = true
binRes, pErr := origRes.Pack()
if pErr == nil {
binMsg, mErr := json.Marshal(redisMessage{
Key: s.Key,
Type: messageTypeCache,
Message: binRes,
Client: c.id,
})
if mErr == nil {
c.client.Publish(ctx, SyncChannelName, binMsg)
}
c.client.Set(ctx,
prefixKey(s.Key),
binRes,
c.getTTL(origRes))
// publishMessageFromBuffer publishes a message from the buffer to the redis channel and stores it in the cache
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *CacheEntry) {
psMsg, err := json.Marshal(redisMessage{
Key: s.Key,
Type: messageTypeCache,
Message: s.Entry,
Client: c.id,
})
if err == nil {
c.client.Publish(ctx, SyncChannelName, psMsg)
}
c.client.Set(ctx,
prefixKey(s.Key),
s.Entry,
util.ToTTLDuration(s.TTL))
}
func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) {
@ -248,26 +245,24 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message)
if !bytes.Equal(rm.Client, c.id) {
switch rm.Type {
case messageTypeCache:
var cm *CacheMessage
cm, err := convertMessage(&rm, 0)
cm, err := convertMessage(0, rm.Key, rm.Message)
if err != nil {
c.l.Error("Processing CacheMessage error: ", err)
c.l.Error(err)
return
}
util.CtxSend(ctx, c.CacheChannel, cm)
util.CtxSend(ctx, c.CacheChannel, &cm)
case messageTypeEnable:
var msg EnabledMessage
msg := new(EnabledMessage)
if err := json.Unmarshal(rm.Message, &msg); err != nil {
if err := json.Unmarshal(rm.Message, msg); err != nil {
c.l.Error("Processing EnabledMessage error: ", err)
return
}
util.CtxSend(ctx, c.EnabledChannel, &msg)
util.CtxSend(ctx, c.EnabledChannel, msg)
default:
c.l.Warn("Unknown message type: ", rm.Type)
}
@ -275,69 +270,48 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message)
}
// getResponse returns model.Response for a key
func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) {
func (c *Client) getResponse(ctx context.Context, key string) (*CacheEntry, error) {
resp, err := c.client.Get(ctx, key).Result()
if err == nil {
var ttl time.Duration
ttl, err = c.client.TTL(ctx, key).Result()
if err == nil {
var result *CacheMessage
result, err = convertMessage(&redisMessage{
Key: cleanKey(key),
Message: []byte(resp),
}, ttl)
if err != nil {
return nil, fmt.Errorf("conversion error: %w", err)
}
return result, nil
}
if err != nil {
return nil, err
}
return nil, err
ttl, err := c.client.TTL(ctx, key).Result()
if err != nil {
return nil, err
}
result, err := convertMessage(ttl, cleanKey(key), []byte(resp))
if err != nil {
return nil, err
}
return &result, nil
}
// convertMessage converts redisMessage to CacheMessage
func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, error) {
msg := dns.Msg{}
err := msg.Unpack(message.Message)
if err == nil {
if ttl > 0 {
for _, a := range msg.Answer {
a.Header().Ttl = uint32(ttl.Seconds())
}
}
res := &CacheMessage{
Key: message.Key,
Response: &model.Response{
RType: model.ResponseTypeCACHED,
Reason: cacheReason,
Res: &msg,
},
}
return res, nil
func convertMessage[T util.TTLInput](ttl T, key string, message []byte) (CacheEntry, error) {
res := CacheEntry{
TTL: util.ToTTL(ttl),
Key: key,
Entry: message,
}
return nil, err
}
packErr := fmt.Errorf("invalid message for key %s", key)
if len(message) == 0 {
return res, packErr
}
// getTTL of dns message or return defaultCacheTime if 0
func (c *Client) getTTL(dns *dns.Msg) time.Duration {
ttl := uint32(math.MaxInt32)
for _, a := range dns.Answer {
ttl = min(ttl, a.Header().Ttl)
msg := new(dns.Msg)
if err := msg.Unpack(message); err != nil {
return res, packErr
}
if ttl == 0 {
return defaultCacheTime
res.TTL = util.GetAnswerMinTTL(msg)
}
return time.Duration(ttl) * time.Second
return res, nil
}
// prefixKey with CacheStorePrefix

View File

@ -96,10 +96,12 @@ var _ = Describe("Redis client", func() {
By("publish new message with TTL > 0", func() {
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", res)
binRes, err := res.Pack()
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", 123, binRes)
})
By("Database has one entry with correct TTL", func() {
@ -111,34 +113,6 @@ var _ = Describe("Redis client", func() {
Expect(ttl.Seconds()).Should(BeNumerically("~", 123))
})
})
It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
By("Database is empty", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
}).Should(BeEmpty())
})
By("publish new message with TTL = 0", func() {
res, err := util.NewMsgWithAnswer("example.com.", 0, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", res)
})
By("Database has one entry with default TTL", func() {
Eventually(func() bool {
return redisServer.DB(redisConfig.Database).Exists(exampleComKey)
}).Should(BeTrue())
ttl := redisServer.DB(redisConfig.Database).TTL(exampleComKey)
Expect(ttl.Seconds()).Should(BeNumerically("~", defaultCacheTime.Seconds()))
})
})
})
When("Redis client publishes 'enabled' message", func() {
It("should propagate the message over redis", func(ctx context.Context) {
@ -222,7 +196,7 @@ var _ = Describe("Redis client", func() {
rec := redisServer.Publish(SyncChannelName, string(binMsg))
Expect(rec).Should(Equal(1))
Eventually(func() chan *CacheMessage {
Eventually(func() chan *CacheEntry {
return redisClient.CacheChannel
}).Should(HaveLen(lenE + 1))
}, SpecTimeout(time.Second*6))
@ -255,7 +229,7 @@ var _ = Describe("Redis client", func() {
return redisClient.EnabledChannel
}).Should(HaveLen(lenE))
Eventually(func() chan *CacheMessage {
Eventually(func() chan *CacheEntry {
return redisClient.CacheChannel
}).Should(HaveLen(lenC))
}, SpecTimeout(time.Second*6))
@ -288,7 +262,7 @@ var _ = Describe("Redis client", func() {
return redisClient.EnabledChannel
}).Should(HaveLen(lenE))
Eventually(func() chan *CacheMessage {
Eventually(func() chan *CacheEntry {
return redisClient.CacheChannel
}).Should(HaveLen(lenC))
}, SpecTimeout(time.Second*6))
@ -312,13 +286,13 @@ var _ = Describe("Redis client", func() {
})
By("Put valid data in Redis by publishing the cache entry", func() {
var res *dns.Msg
res, err = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", res)
binRes, err := res.Pack()
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", 123, binRes)
})
By("Database has one entry now", func() {

View File

@ -2,9 +2,6 @@ package resolver
import (
"context"
"fmt"
"math"
"sync/atomic"
"time"
"github.com/0xERR0R/blocky/cache/expirationcache"
@ -18,7 +15,9 @@ import (
"github.com/sirupsen/logrus"
)
const defaultCachingCleanUpInterval = 5 * time.Second
const (
defaultCachingCleanUpInterval = 5 * time.Second
)
// CachingResolver caches answers from dns queries with their TTL time,
// to avoid external resolver calls for recurrent queries
@ -107,28 +106,29 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
ctx, logger := r.log(ctx)
logger = logger.WithField("domain", util.Obfuscate(domainName))
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
logger.Debugf("prefetching %s", qType)
req := newRequest(dns.Fqdn(domainName), qType)
response, err := r.next.Resolve(ctx, req)
if err != nil {
logger.WithError(err).Warn("cache prefetch failed")
if err == nil {
if response.Res.Rcode == dns.RcodeSuccess {
packed, err := response.Res.Pack()
if err != nil {
logger.Error("unable to pack response", err)
return nil, 0
}
return &packed, r.adjustTTLs(response.Res.Answer)
}
} else {
util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err)
return nil, 0
}
return nil, 0
ttl, res := r.createCacheEntry(logger, response.Res)
if ttl == 0 || len(res) == 0 {
return nil, 0
}
if r.redisClient != nil {
r.redisClient.PublishCache(cacheKey, ttl, res)
}
return &res, util.ToTTLDuration(ttl)
}
func (r *CachingResolver) redisSubscriber(ctx context.Context) {
@ -138,9 +138,11 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) {
select {
case rc := <-r.redisClient.CacheChannel:
if rc != nil {
logger.Debug("Received key from redis: ", rc.Key)
ttl := r.adjustTTLs(rc.Response.Res.Answer)
r.putInCache(ctx, rc.Key, rc.Response, ttl, false)
_, domain := util.ExtractCacheKey(rc.Key)
logger.WithField("domain", util.Obfuscate(domain)).Debug("received from redis")
r.resultCache.Put(rc.Key, &rc.Entry, util.ToTTLDuration(rc.TTL))
}
case <-ctx.Done():
@ -172,63 +174,57 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
logger := logger.WithField("domain", util.Obfuscate(domain))
val, ttl := r.getFromCache(logger, cacheKey)
cacheEntry := r.getFromCache(logger, cacheKey)
if val != nil {
if cacheEntry != nil {
logger.Debug("domain is cached")
val.SetRcode(request.Req, val.Rcode)
cacheEntry.SetRcode(request.Req, cacheEntry.Rcode)
// Adjust TTL
setTTLInCachedResponse(val, ttl)
if val.Rcode == dns.RcodeSuccess {
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
if cacheEntry.Rcode == dns.RcodeSuccess {
return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
}
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
}
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
response, err = r.next.Resolve(ctx, request)
response, err = r.next.Resolve(ctx, request)
if err == nil {
cacheTTL := r.adjustTTLs(response.Res.Answer)
r.putInCache(ctx, cacheKey, response, cacheTTL, true)
ttl, cacheEntry := r.createCacheEntry(logger, response.Res)
if ttl != 0 && len(cacheEntry) > 0 {
r.resultCache.Put(cacheKey, &cacheEntry, util.ToTTLDuration(ttl))
if r.redisClient != nil {
r.redisClient.PublishCache(cacheKey, ttl, cacheEntry)
}
}
}
}
return response, err
}
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) {
val, ttl := r.resultCache.Get(key)
if val == nil {
return nil, 0
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) *dns.Msg {
raw, ttl := r.resultCache.Get(key)
if raw == nil {
return nil
}
res := new(dns.Msg)
err := res.Unpack(*val)
err := res.Unpack(*raw)
if err != nil {
logger.Error("can't unpack cached entry. Cache malformed?", err)
return nil, 0
return nil
}
return res, ttl
}
// Adjust TTL
util.AdjustAnswerTTL(res, ttl)
func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
minTTL := uint32(math.MaxInt32)
// find smallest TTL first
for _, rr := range resp.Answer {
minTTL = min(minTTL, rr.Header().Ttl)
}
for _, rr := range resp.Answer {
rr.Header().Ttl = rr.Header().Ttl - minTTL + uint32(ttl.Seconds())
}
return res
}
// isRequestCacheable returns true if the request should be cached
@ -244,72 +240,50 @@ func isRequestCacheable(request *model.Request) bool {
return true
}
// isResponseCacheable returns true if the response is not truncated and its CD flag isn't set.
func isResponseCacheable(msg *dns.Msg) bool {
// we don't cache truncated responses and responses with CD flag
return !msg.Truncated && !msg.CheckingDisabled
func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 {
// if response is empty or negative, return negative cache time from config
if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError {
return util.ToTTL(r.cfg.CacheTimeNegative)
}
// if response is truncated or CD flag is set, return noCacheTTL since we don't cache these responses
if response.Truncated || response.CheckingDisabled {
return 0
}
// if response is not successful, return noCacheTTL since we don't cache these responses
if response.Rcode != dns.RcodeSuccess {
return 0
}
// adjust TTLs of all answers to match the configured min and max caching times
util.SetAnswerMinMaxTTL(response, r.cfg.MinCachingTime, r.cfg.MaxCachingTime)
return util.GetAnswerMinTTL(response)
}
func (r *CachingResolver) putInCache(
ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool,
) {
respCopy := response.Res.Copy()
func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg) (uint32, []byte) {
ttl := r.modifyResponseTTL(input)
if ttl == 0 {
logger.Debug("response is not cacheable")
return 0, nil
}
internalMsg := input.Copy()
internalMsg.Compress = true
// don't cache any EDNS OPT records
util.RemoveEdns0Record(respCopy)
util.RemoveEdns0Record(internalMsg)
packed, err := respCopy.Pack()
util.LogOnError(ctx, "error on packing", err)
packed, err := internalMsg.Pack()
if err != nil {
logger.WithError(err).Warn("response packing failed")
if err == nil {
if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) {
// put value into cache
r.resultCache.Put(cacheKey, &packed, ttl)
} else if response.Res.Rcode == dns.RcodeNameError {
if r.cfg.CacheTimeNegative.IsAboveZero() {
// put negative cache if result code is NXDOMAIN
r.resultCache.Put(cacheKey, &packed, r.cfg.CacheTimeNegative.ToDuration())
}
}
return 0, nil
}
if publish && r.redisClient != nil {
res := *respCopy
r.redisClient.PublishCache(cacheKey, &res)
}
}
// adjustTTLs calculates and returns the min TTL (considers also the min and max cache time)
// for all records from answer or a negative cache time for empty answer
// adjust the TTL in the answer header accordingly
func (r *CachingResolver) adjustTTLs(answer []dns.RR) (ttl time.Duration) {
minTTL := uint32(math.MaxInt32)
if len(answer) == 0 {
return r.cfg.CacheTimeNegative.ToDuration()
}
for _, a := range answer {
// if TTL < mitTTL -> adjust the value, set minTTL
if r.cfg.MinCachingTime.IsAboveZero() {
if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() {
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32())
}
}
if r.cfg.MaxCachingTime.IsAboveZero() {
if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() {
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32())
}
}
headerTTL := atomic.LoadUint32(&a.Header().Ttl)
if minTTL > headerTTL {
minTTL = headerTTL
}
}
return time.Duration(minTTL) * time.Second
return ttl, packed
}
func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) {

View File

@ -782,13 +782,13 @@ var _ = Describe("CachingResolver", func() {
request := newRequest("example2.com.", A)
domain := util.ExtractDomain(request.Req.Question[0])
cacheKey := util.GenerateCacheKey(A, domain)
redisMockMsg := &redis.CacheMessage{
Key: cacheKey,
Response: &Response{
RType: ResponseTypeCACHED,
Reason: "MOCK_REDIS",
Res: mockAnswer,
},
binMsg, err := mockAnswer.Pack()
Expect(err).Should(Succeed())
redisMockMsg := &redis.CacheEntry{
TTL: 123,
Key: cacheKey,
Entry: binMsg,
}
redisClient.CacheChannel <- redisMockMsg

144
util/dns.go Normal file
View File

@ -0,0 +1,144 @@
package util
import (
"math"
"sync/atomic"
"time"
"github.com/miekg/dns"
)
// TTLInput is the input type for TTL values and consists of the following underlying types:
// int, uint, uint32, int64
type TTLInput interface {
~int | ~uint | ~uint32 | ~int64
}
// ToTTL converts the input to a TTL of seconds as uint32.
//
// If the input is of underlying type time.Duration, the value is converted to seconds.
//
// If the input is negative, the TTL is set to 0.
//
// If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32.
func ToTTL[T TTLInput](input T) uint32 {
// fast return if the input is zero or below
if input <= 0 {
return 0
}
// fast return if the input is already of type uint32
if ui32Type, ok := any(input).(uint32); ok {
return ui32Type
}
// use int64 as the intermediate type for conversion
res := int64(input)
// check if the input is of underlying type time.Duration
if durType, ok := any(input).(interface{ Seconds() float64 }); ok {
res = int64(durType.Seconds())
}
// check if the value is negative or greater than the maximum value of uint32
if res < 0 {
// there is no negative TTL
return 0
} else if res > math.MaxUint32 {
// since TTL is a 32-bit unsigned integer, the maximum value is math.MaxUint32
return math.MaxUint32
}
// return the value as uint32
return uint32(res)
}
// ToTTLDuration converts the input to a time.Duration.
//
// The input is converted to a TTL of seconds as uint32 and then to a time.Duration.
func ToTTLDuration[T TTLInput](input T) time.Duration {
return time.Duration(ToTTL(input)) * time.Second
}
// SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to
// the minimum TTL.
func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) {
if minTTL := ToTTL(min); minTTL != 0 {
for _, answer := range msg.Answer {
if atomic.LoadUint32(&answer.Header().Ttl) < minTTL {
atomic.StoreUint32(&answer.Header().Ttl, minTTL)
}
}
}
}
// SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL
// to the maximum TTL.
func SetAnswerMaxTTL[T TTLInput](msg *dns.Msg, max T) {
if maxTTL := ToTTL(max); maxTTL != 0 {
for _, answer := range msg.Answer {
if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL {
atomic.StoreUint32(&answer.Header().Ttl, maxTTL)
}
}
}
}
// SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL
// to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL.
func SetAnswerMinMaxTTL[T, TT TTLInput](msg *dns.Msg, min T, max TT) {
minTTL := ToTTL(min)
maxTTL := ToTTL(max)
switch {
case minTTL == 0 && maxTTL == 0:
// no TTL specified, fast return
return
case minTTL != 0 && maxTTL == 0:
// only minimum TTL specified
SetAnswerMinTTL(msg, min)
case minTTL == 0 && maxTTL != 0:
// only maximum TTL specified
SetAnswerMaxTTL(msg, max)
default:
// both minimum and maximum TTL specified
for _, answer := range msg.Answer {
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
if headerTTL < minTTL {
atomic.StoreUint32(&answer.Header().Ttl, minTTL)
} else if headerTTL > maxTTL {
atomic.StoreUint32(&answer.Header().Ttl, maxTTL)
}
}
}
}
// GetMinAnswerTTL returns the lowest TTL of all answers in the message.
func GetAnswerMinTTL(msg *dns.Msg) uint32 {
var minTTL atomic.Uint32
// initialize minTTL with the maximum value of uint32
minTTL.Store(math.MaxUint32)
for _, answer := range msg.Answer {
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
if headerTTL < minTTL.Load() {
minTTL.Store(headerTTL)
}
}
return minTTL.Load()
}
// AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL
// and the answer's TTL plus the specified adjustment.
//
// If the adjustment is zero, the TTL is not changed.
func AdjustAnswerTTL[T TTLInput](msg *dns.Msg, adjustment T) {
adjustmentTTL := ToTTL(adjustment)
minTTL := GetAnswerMinTTL(msg)
for _, answer := range msg.Answer {
headerTTL := atomic.LoadUint32(&answer.Header().Ttl)
atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL)
}
}

37
util/dns_test.go Normal file
View File

@ -0,0 +1,37 @@
package util
import (
"math"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("EDNS0 utils", func() {
DescribeTable("ToTTL",
func(input interface{}, expected int) {
res := uint32(0)
switch it := input.(type) {
case uint32:
res = ToTTL(it)
case int:
res = ToTTL(it)
case int64:
res = ToTTL(it)
case time.Duration:
res = ToTTL(it)
default:
Fail("unsupported type")
}
Expect(ToTTL(res)).Should(Equal(uint32(expected)))
},
Entry("should return 0 for negative input", -1, 0),
Entry("should return uint32 for uint32 input", uint32(1), 1),
Entry("should return uint32 for int input", 1, 1),
Entry("should return uint32 for int64 input", int64(1), 1),
Entry("should return seconds for time.Duration input", time.Second, 1),
Entry("should return math.MaxUint32 for too large input", int64(math.MaxUint32)+1, math.MaxUint32),
)
})