WIP: switch to rueidis

This commit is contained in:
Dimitri Herzog 2023-03-18 22:47:48 +01:00
parent a05441f452
commit e106719fcb
13 changed files with 224 additions and 475 deletions

View File

@ -0,0 +1,89 @@
package expirationcache
import (
"context"
"fmt"
"strings"
"time"
"github.com/miekg/dns"
"github.com/rueian/rueidis"
)
type RedisCache struct {
rdb rueidis.Client
name string
}
func Key(k ...string) string {
return fmt.Sprintf("blocky:%s", strings.Join(k, ":"))
}
func (r *RedisCache) cacheKey(key string) string {
return Key("cache", r.name, key)
}
func toSeconds(t time.Duration) int64 {
return int64(t.Seconds())
}
func NoResult(res rueidis.RedisResult) bool {
err := res.Error()
return err != nil && rueidis.IsRedisNil(err)
}
func (r *RedisCache) Put(key string, val *dns.Msg, expiration time.Duration) {
b, err := val.Pack()
if err != nil {
panic(err)
}
cmd := r.rdb.B().Setex().Key(r.cacheKey(key)).Seconds(toSeconds(expiration)).Value(rueidis.BinaryString(b)).Build()
r.rdb.Do(context.Background(), cmd).Error()
if err != nil {
panic(err)
}
}
func (r *RedisCache) Get(key string) (val *dns.Msg, expiration time.Duration) {
cmd := r.rdb.B().Get().Key(r.cacheKey(key)).Cache()
resp := r.rdb.DoCache(context.Background(), cmd, 600*time.Second)
if NoResult(resp) {
return nil, 0
}
err := resp.Error()
if err != nil {
panic(err)
}
respStr, err := resp.ToString()
if err != nil {
panic(err)
}
bytesVal := []byte(respStr)
msg := new(dns.Msg)
err = msg.Unpack(bytesVal)
if err != nil {
panic(err)
}
return msg, time.Duration(resp.CacheTTL() * int64(time.Second))
}
func (r *RedisCache) TotalCount() int {
// TODO implement me
return 0
}
func (r *RedisCache) Clear() {
// TODO implement me
}
func NewRedisCache(rdb rueidis.Client, name string) ExpiringCache[dns.Msg] {
return &RedisCache{
rdb: rdb,
name: name,
}
}

View File

@ -3,23 +3,20 @@ package stringcache
import (
"context"
"fmt"
"time"
"github.com/go-redis/redis/v8"
"github.com/rueian/rueidis"
)
type RedisGroupedStringCache struct {
rdb *redis.Client
rdb rueidis.Client
name string
}
func NewRedisGroupedStringCache(name string) *RedisGroupedStringCache {
func NewRedisGroupedStringCache(name string, rdb rueidis.Client) *RedisGroupedStringCache {
return &RedisGroupedStringCache{
name: name,
rdb: redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password set
DB: 0, // use default DB
}),
rdb: rdb,
}
}
@ -28,25 +25,28 @@ func (r *RedisGroupedStringCache) cacheKey(groupName string) string {
}
func (r *RedisGroupedStringCache) ElementCount(group string) int {
return int(r.rdb.SCard(context.Background(), r.cacheKey(group)).Val())
res, err := r.rdb.DoCache(context.Background(), r.rdb.B().Scard().Key(r.cacheKey(group)).Cache(), 600*time.Second).ToInt64()
if err != nil {
return 0
}
return int(res)
}
func (r *RedisGroupedStringCache) Contains(searchString string, groups []string) []string {
cmds, err := r.rdb.Pipelined(context.Background(), func(pipeline redis.Pipeliner) error {
for _, group := range groups {
pipeline.SIsMember(context.Background(), r.cacheKey(group), searchString)
}
return nil
})
if err != nil {
panic(err)
var cmds []rueidis.CacheableTTL
for _, group := range groups {
cmds = append(cmds, rueidis.CT(r.rdb.B().Sismember().Key(r.cacheKey(group)).Member(searchString).Cache(), time.Second))
}
resps := r.rdb.DoMultiCache(context.Background(), cmds...)
var result []string
for ix, group := range groups {
if cmds[ix].(*redis.BoolCmd).Val() {
r, err := resps[ix].AsBool()
if err != nil {
panic(err)
}
if r {
result = append(result, group)
}
}
@ -55,30 +55,26 @@ func (r *RedisGroupedStringCache) Contains(searchString string, groups []string)
}
func (r *RedisGroupedStringCache) Refresh(group string) GroupFactory {
pipeline := r.rdb.Pipeline()
pipeline.Del(context.Background(), r.cacheKey(group))
cmds := rueidis.Commands{r.rdb.B().Del().Key(r.cacheKey(group)).Build()}
f := &RedisGroupFactory{
rdb: r.rdb,
name: r.cacheKey(group),
pipeline: pipeline,
rdb: r.rdb,
name: r.cacheKey(group),
cmds: cmds,
}
return f
}
type RedisGroupFactory struct {
rdb *redis.Client
name string
pipeline redis.Pipeliner
cnt int
rdb rueidis.Client
name string
cmds rueidis.Commands
cnt int
}
func (r *RedisGroupFactory) AddEntry(entry string) {
err := r.pipeline.SAdd(context.Background(), r.name, entry).Err()
if err != nil {
panic(err)
}
r.cmds = append(r.cmds, r.rdb.B().Sadd().Key(r.name).Member(entry).Build())
r.cnt++
}
@ -88,8 +84,5 @@ func (r *RedisGroupFactory) Count() int {
}
func (r *RedisGroupFactory) Finish() {
_, err := r.pipeline.Exec(context.Background())
if err != nil {
panic(err)
}
_ = r.rdb.DoMulti(context.Background(), r.cmds...)
}

View File

@ -10,9 +10,7 @@ import (
)
var _ = Describe("BlockingConfig", func() {
var (
cfg BlockingConfig
)
var cfg BlockingConfig
suiteBeforeEach()

View File

@ -232,16 +232,16 @@ type (
// RedisConfig configuration for the redis connection
type RedisConfig struct {
Address string `yaml:"address"`
Username string `yaml:"username" default:""`
Password string `yaml:"password" default:""`
Database int `yaml:"database" default:"0"`
Required bool `yaml:"required" default:"false"`
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
SentinelUsername string `yaml:"sentinelUsername" default:""`
SentinelPassword string `yaml:"sentinelPassword" default:""`
SentinelAddresses []string `yaml:"sentinelAddresses"`
Addresses []string `yaml:"addresses"`
Username string `yaml:"username" default:""`
Password string `yaml:"password" default:""`
Database int `yaml:"database" default:"0"`
SentinelUsername string `yaml:"sentinelUsername" default:""`
SentinelPassword string `yaml:"sentinelPassword" default:""`
SentinelMasterSet string `yaml:"sentinelMasterSet" default:""`
ClientMaxCachingTime Duration `yaml:"clientMaxCachingTime" default:"1h"`
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
}
type (

3
go.mod
View File

@ -55,7 +55,7 @@ require (
github.com/docker/go-units v0.5.0 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
github.com/jackc/pgx/v5 v5.3.0 // indirect
github.com/klauspost/compress v1.11.13 // indirect
github.com/magiconair/properties v1.8.7 // indirect
@ -66,6 +66,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0-rc2 // indirect
github.com/opencontainers/runc v1.1.3 // indirect
github.com/rueian/rueidis v0.0.95 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/tools/cmd/cover v0.1.0-deprecated // indirect
google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad // indirect

4
go.sum
View File

@ -229,6 +229,8 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@ -396,6 +398,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rueian/rueidis v0.0.95 h1:zrqNGlIH+YYokKxCCSEnOCIrjFVFb29IqNurQEVqXMc=
github.com/rueian/rueidis v0.0.95/go.mod h1:Ziv7p67TsXd3zAcQ5+BDdQvOsUE9N68u6JcA7/rE+x8=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=

View File

@ -11,6 +11,10 @@ import (
"strings"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/util"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
@ -70,11 +74,23 @@ func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeri
processingConcurrency = defaultProcessingConcurrency
}
b := &ListCache{
groupedCache: stringcache.NewChainedGroupedCache(
rdb, err := redis.NewRedisClient(&config.GetConfig().Redis)
util.FatalOnError("can't create redis client", err)
var groupedCache stringcache.GroupedStringCache
if rdb != nil {
// redis
groupedCache = stringcache.NewChainedGroupedCache(
stringcache.NewRedisGroupedStringCache(t.String(), rdb),
stringcache.NewInMemoryGroupedRegexCache())
} else {
// in-memory
groupedCache = stringcache.NewChainedGroupedCache(
stringcache.NewInMemoryGroupedStringCache(),
stringcache.NewInMemoryGroupedRegexCache(),
),
stringcache.NewInMemoryGroupedRegexCache())
}
b := &ListCache{
groupedCache: groupedCache,
groupToLinks: groupToLinks,
refreshPeriod: refreshPeriod,
downloader: downloader,

View File

@ -9,10 +9,8 @@ import (
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
@ -65,67 +63,6 @@ type Client struct {
EnabledChannel chan *EnabledMessage
}
// New creates a new redis client
func New(cfg *config.RedisConfig) (*Client, error) {
// disable redis if no address is provided
if cfg == nil || len(cfg.Address) == 0 {
return nil, nil //nolint:nilnil
}
var rdb *redis.Client
if len(cfg.SentinelAddresses) > 0 {
rdb = redis.NewFailoverClient(&redis.FailoverOptions{
MasterName: cfg.Address,
SentinelUsername: cfg.Username,
SentinelPassword: cfg.SentinelPassword,
SentinelAddrs: cfg.SentinelAddresses,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.Database,
MaxRetries: cfg.ConnectionAttempts,
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
})
} else {
rdb = redis.NewClient(&redis.Options{
Addr: cfg.Address,
Username: cfg.Username,
Password: cfg.Password,
DB: cfg.Database,
MaxRetries: cfg.ConnectionAttempts,
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
})
}
ctx := context.Background()
_, err := rdb.Ping(ctx).Result()
if err == nil {
var id []byte
id, err = uuid.New().MarshalBinary()
if err == nil {
// construct client
res := &Client{
config: cfg,
client: rdb,
l: log.PrefixedLog("redis"),
ctx: ctx,
id: id,
sendBuffer: make(chan *bufferMessage, chanCap),
CacheChannel: make(chan *CacheMessage, chanCap),
EnabledChannel: make(chan *EnabledMessage, chanCap),
}
// start channel handling go routine
err = res.startup()
return res, err
}
}
return nil, err
}
// PublishCache publish cache to redis async
func (c *Client) PublishCache(key string, message *dns.Msg) {
if len(key) > 0 && message != nil {

View File

@ -1,312 +0,0 @@
package redis
import (
"encoding/json"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
"github.com/alicebob/miniredis/v2"
"github.com/creasty/defaults"
"github.com/google/uuid"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var (
redisServer *miniredis.Miniredis
redisClient *Client
redisConfig *config.RedisConfig
err error
)
var _ = Describe("Redis client", func() {
BeforeEach(func() {
redisServer, err = miniredis.Run()
Expect(err).Should(Succeed())
DeferCleanup(redisServer.Close)
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = redisServer.Addr()
redisConfig = &rcfg
redisClient, err = New(redisConfig)
Expect(err).Should(Succeed())
Expect(redisClient).ShouldNot(BeNil())
})
Describe("Client creation", func() {
When("redis configuration has no address", func() {
It("should return nil without error", func() {
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
Expect(New(&rcfg)).Should(BeNil())
})
})
When("redis configuration has invalid address", func() {
It("should fail with error", func() {
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = "127.0.0.1:0"
_, err = New(&rcfg)
Expect(err).Should(HaveOccurred())
})
})
When("redis configuration has invalid password", func() {
It("should fail with error", func() {
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = redisServer.Addr()
rcfg.Password = "wrong"
_, err = New(&rcfg)
Expect(err).Should(HaveOccurred())
})
})
})
Describe("Publish message", func() {
When("Redis client publishes 'cache' message", func() {
It("One new entry with TTL > 0 should be persisted in the database", func() {
By("Database is empty", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
}).Should(BeEmpty())
})
By("publish new message with TTL > 0", func() {
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", res)
})
By("Database has one entry with correct TTL", func() {
Eventually(func() bool {
return redisServer.DB(redisConfig.Database).Exists(CacheStorePrefix + "example.com")
}).Should(BeTrue())
ttl := redisServer.DB(redisConfig.Database).TTL(CacheStorePrefix + "example.com")
Expect(ttl.Seconds()).Should(BeNumerically("~", 123))
})
})
It("One new entry with default TTL should be persisted in the database", func() {
By("Database is empty", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
}).Should(BeEmpty())
})
By("publish new message with TTL = 0", func() {
res, err := util.NewMsgWithAnswer("example.com.", 0, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", res)
})
By("Database has one entry with default TTL", func() {
Eventually(func() bool {
return redisServer.DB(redisConfig.Database).Exists(CacheStorePrefix + "example.com")
}).Should(BeTrue())
ttl := redisServer.DB(redisConfig.Database).TTL(CacheStorePrefix + "example.com")
Expect(ttl.Seconds()).Should(BeNumerically("~", defaultCacheTime.Seconds()))
})
})
})
When("Redis client publishes 'enabled' message", func() {
It("should propagate the message over redis", func() {
redisClient.PublishEnabled(&EnabledMessage{
State: true,
})
Eventually(func() map[string]int {
return redisServer.PubSubNumSub(SyncChannelName)
}).Should(HaveLen(1))
})
})
})
Describe("Receive message", func() {
When("'enabled' message is received", func() {
It("should propagate the message over the channel", func() {
var binState []byte
binState, err = json.Marshal(EnabledMessage{State: true})
Expect(err).Should(Succeed())
var id []byte
id, err = uuid.New().MarshalBinary()
Expect(err).Should(Succeed())
var binMsg []byte
binMsg, err = json.Marshal(redisMessage{
Type: messageTypeEnable,
Message: binState,
Client: id,
})
Expect(err).Should(Succeed())
lenE := len(redisClient.EnabledChannel)
rec := redisServer.Publish(SyncChannelName, string(binMsg))
Expect(rec).Should(Equal(1))
Eventually(func() chan *EnabledMessage {
return redisClient.EnabledChannel
}).Should(HaveLen(lenE + 1))
})
})
When("'cache' message is received", func() {
It("should propagate the message over the channel", func() {
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
var binState []byte
binState, err = res.Pack()
Expect(err).Should(Succeed())
var id []byte
id, err = uuid.New().MarshalBinary()
Expect(err).Should(Succeed())
var binMsg []byte
binMsg, err = json.Marshal(redisMessage{
Key: "example.com",
Type: messageTypeCache,
Message: binState,
Client: id,
})
Expect(err).Should(Succeed())
lenE := len(redisClient.CacheChannel)
rec := redisServer.Publish(SyncChannelName, string(binMsg))
Expect(rec).Should(Equal(1))
Eventually(func() chan *CacheMessage {
return redisClient.CacheChannel
}).Should(HaveLen(lenE + 1))
})
})
When("wrong data is received", func() {
It("should not propagate the message over the channel if data is wrong", func() {
var id []byte
id, err = uuid.New().MarshalBinary()
Expect(err).Should(Succeed())
var binMsg []byte
binMsg, err = json.Marshal(redisMessage{
Key: "unknown",
Type: messageTypeCache,
Message: []byte("test"),
Client: id,
})
Expect(err).Should(Succeed())
lenE := len(redisClient.EnabledChannel)
lenC := len(redisClient.CacheChannel)
rec := redisServer.Publish(SyncChannelName, string(binMsg))
Expect(rec).Should(Equal(1))
Eventually(func() chan *EnabledMessage {
return redisClient.EnabledChannel
}).Should(HaveLen(lenE))
Eventually(func() chan *CacheMessage {
return redisClient.CacheChannel
}).Should(HaveLen(lenC))
})
It("should not propagate the message over the channel if type is wrong", func() {
var id []byte
id, err = uuid.New().MarshalBinary()
Expect(err).Should(Succeed())
var binMsg []byte
binMsg, err = json.Marshal(redisMessage{
Key: "unknown",
Type: 99,
Message: []byte("test"),
Client: id,
})
Expect(err).Should(Succeed())
lenE := len(redisClient.EnabledChannel)
lenC := len(redisClient.CacheChannel)
rec := redisServer.Publish(SyncChannelName, string(binMsg))
Expect(rec).Should(Equal(1))
time.Sleep(2 * time.Second)
Eventually(func() chan *EnabledMessage {
return redisClient.EnabledChannel
}).Should(HaveLen(lenE))
Eventually(func() chan *CacheMessage {
return redisClient.CacheChannel
}).Should(HaveLen(lenC))
})
})
})
Describe("Read the redis cache and publish it to the channel", func() {
When("GetRedisCache is called with valid database entries", func() {
It("Should read data from Redis and propagate it via cache channel", func() {
By("Database is empty", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
}).Should(BeEmpty())
})
By("Put valid data in Redis by publishing the cache entry", func() {
var res *dns.Msg
res, err = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
redisClient.PublishCache("example.com", res)
})
By("Database has one entry now", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
}).Should(HaveLen(1))
})
By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() {
redisClient.GetRedisCache()
Eventually(redisClient.CacheChannel).Should(HaveLen(1))
})
})
})
When("GetRedisCache is called and database contains not valid entry", func() {
It("Should do nothing (only log error)", func() {
Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed())
redisClient.GetRedisCache()
Consistently(redisClient.CacheChannel).Should(BeEmpty())
})
})
})
})

50
redis/rueidis.go Normal file
View File

@ -0,0 +1,50 @@
package redis
import (
"fmt"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
"github.com/rueian/rueidis"
)
// New creates a new redis client
func NewRedisClient(cfg *config.RedisConfig) (rueidis.Client, error) {
// disable redis if no address is provided
if cfg == nil || len(cfg.Addresses) == 0 {
return nil, nil //nolint:nilnil
}
roption := rueidis.ClientOption{
InitAddress: cfg.Addresses,
Password: cfg.Password,
Username: cfg.Username,
SelectDB: cfg.Database,
ClientName: fmt.Sprintf("blocky-%s", util.HostnameString()),
ClientTrackingOptions: []string{"PREFIX", "blocky:", "BCAST"},
}
if len(cfg.SentinelMasterSet) > 0 {
roption.Sentinel = rueidis.SentinelOption{
Username: cfg.SentinelUsername,
Password: cfg.SentinelPassword,
MasterSet: cfg.SentinelMasterSet,
}
}
var err error
var client rueidis.Client
for i := 0; i < cfg.ConnectionAttempts; i++ {
client, err = rueidis.NewClient(roption)
if err == nil {
break
}
time.Sleep(time.Duration(cfg.ConnectionCooldown))
}
if err != nil {
return nil, err
}
return client, nil
}

View File

@ -28,7 +28,7 @@ type CachingResolver struct {
emitMetricEvents bool // disabled by Bootstrap
resultCache expirationcache.ExpiringCache[cacheValue]
resultCache expirationcache.ExpiringCache[dns.Msg]
prefetchingNameCache expirationcache.ExpiringCache[int]
redisClient *redis.Client
}
@ -55,30 +55,20 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
configureCaches(c, &cfg)
if c.redisClient != nil {
setupRedisCacheSubscriber(c)
c.redisClient.GetRedisCache()
}
return c
}
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval)
maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount))
rdb, err := redis.NewRedisClient(&config.GetConfig().Redis)
util.FatalOnError("can't create redis client", err)
if cfg.Prefetching {
c.prefetchingNameCache = expirationcache.NewCache(
expirationcache.WithCleanUpInterval[int](time.Minute),
expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)),
)
c.resultCache = expirationcache.NewCache(
cleanupOption,
maxSizeOption,
expirationcache.WithOnExpiredFn(c.onExpired),
)
if rdb != nil {
// redis
c.resultCache = expirationcache.NewRedisCache(rdb, "query")
} else {
// in-memory
cleanupOption := expirationcache.WithCleanUpInterval[dns.Msg](defaultCachingCleanUpInterval)
maxSizeOption := expirationcache.WithMaxSize[dns.Msg](uint(cfg.MaxItemsCount))
c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption)
}
}
@ -162,14 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain)
if val.prefetch {
// Hit from prefetch cache
r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
}
//if val.prefetch {
// // Hit from prefetch cache
// r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
//}
resp := val.resultMsg.Copy()
resp := val.Copy()
resp.SetReply(request.Req)
resp.Rcode = val.resultMsg.Rcode
resp.Rcode = val.Rcode
// Adjust TTL
for _, rr := range resp.Answer {
@ -215,21 +205,15 @@ func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, log
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, prefetch, publish bool) {
if response.Res.Rcode == dns.RcodeSuccess {
// put value into cache
r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer))
r.resultCache.Put(cacheKey, response.Res, r.adjustTTLs(response.Res.Answer))
} else if response.Res.Rcode == dns.RcodeNameError {
if r.cfg.CacheTimeNegative > 0 {
// put negative cache if result code is NXDOMAIN
r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
r.resultCache.Put(cacheKey, response.Res, r.cfg.CacheTimeNegative.ToDuration())
}
}
r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount())
if publish && r.redisClient != nil {
res := *response.Res
res.Answer = response.Res.Answer
r.redisClient.PublishCache(cacheKey, &res)
}
}
// adjustTTLs calculates and returns the max TTL (considers also the min and max cache time)

View File

@ -151,12 +151,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
return nil, err
}
redisClient, redisErr := redis.New(&cfg.Redis)
if redisErr != nil && cfg.Redis.Required {
return nil, redisErr
}
queryResolver, queryError := createQueryResolver(cfg, bootstrap, redisClient)
queryResolver, queryError := createQueryResolver(cfg, bootstrap, nil)
if queryError != nil {
return nil, queryError
}

View File

@ -165,13 +165,7 @@ func FatalOnError(message string, err error) {
// GenerateCacheKey return cacheKey by query type/domain
func GenerateCacheKey(qType dns.Type, qName string) string {
const qTypeLength = 2
b := make([]byte, qTypeLength+len(qName))
binary.BigEndian.PutUint16(b, uint16(qType))
copy(b[2:], strings.ToLower(qName))
return string(b)
return dns.TypeToString[uint16(qType)] + ":" + qName
}
// ExtractCacheKey return query type/domain from cacheKey