diff --git a/cache/expirationcache/redis_expiration_cache.go b/cache/expirationcache/redis_expiration_cache.go new file mode 100644 index 00000000..22f0bdd3 --- /dev/null +++ b/cache/expirationcache/redis_expiration_cache.go @@ -0,0 +1,89 @@ +package expirationcache + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/rueian/rueidis" +) + +type RedisCache struct { + rdb rueidis.Client + name string +} + +func Key(k ...string) string { + return fmt.Sprintf("blocky:%s", strings.Join(k, ":")) +} + +func (r *RedisCache) cacheKey(key string) string { + return Key("cache", r.name, key) +} + +func toSeconds(t time.Duration) int64 { + return int64(t.Seconds()) +} + +func NoResult(res rueidis.RedisResult) bool { + err := res.Error() + + return err != nil && rueidis.IsRedisNil(err) +} + +func (r *RedisCache) Put(key string, val *dns.Msg, expiration time.Duration) { + b, err := val.Pack() + if err != nil { + panic(err) + } + cmd := r.rdb.B().Setex().Key(r.cacheKey(key)).Seconds(toSeconds(expiration)).Value(rueidis.BinaryString(b)).Build() + r.rdb.Do(context.Background(), cmd).Error() + if err != nil { + panic(err) + } +} + +func (r *RedisCache) Get(key string) (val *dns.Msg, expiration time.Duration) { + cmd := r.rdb.B().Get().Key(r.cacheKey(key)).Cache() + resp := r.rdb.DoCache(context.Background(), cmd, 600*time.Second) + if NoResult(resp) { + return nil, 0 + } + err := resp.Error() + if err != nil { + panic(err) + } + + respStr, err := resp.ToString() + if err != nil { + panic(err) + } + bytesVal := []byte(respStr) + + msg := new(dns.Msg) + err = msg.Unpack(bytesVal) + if err != nil { + panic(err) + } + + return msg, time.Duration(resp.CacheTTL() * int64(time.Second)) +} + +func (r *RedisCache) TotalCount() int { + // TODO implement me + return 0 +} + +func (r *RedisCache) Clear() { + // TODO implement me + +} + +func NewRedisCache(rdb rueidis.Client, name string) ExpiringCache[dns.Msg] { + return &RedisCache{ + rdb: rdb, + name: name, + } +} diff --git a/cache/stringcache/redis_grouped_cache.go b/cache/stringcache/redis_grouped_cache.go index c65f4175..e6a83ea3 100644 --- a/cache/stringcache/redis_grouped_cache.go +++ b/cache/stringcache/redis_grouped_cache.go @@ -3,23 +3,20 @@ package stringcache import ( "context" "fmt" + "time" - "github.com/go-redis/redis/v8" + "github.com/rueian/rueidis" ) type RedisGroupedStringCache struct { - rdb *redis.Client + rdb rueidis.Client name string } -func NewRedisGroupedStringCache(name string) *RedisGroupedStringCache { +func NewRedisGroupedStringCache(name string, rdb rueidis.Client) *RedisGroupedStringCache { return &RedisGroupedStringCache{ name: name, - rdb: redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", // no password set - DB: 0, // use default DB - }), + rdb: rdb, } } @@ -28,25 +25,28 @@ func (r *RedisGroupedStringCache) cacheKey(groupName string) string { } func (r *RedisGroupedStringCache) ElementCount(group string) int { - return int(r.rdb.SCard(context.Background(), r.cacheKey(group)).Val()) + res, err := r.rdb.DoCache(context.Background(), r.rdb.B().Scard().Key(r.cacheKey(group)).Cache(), 600*time.Second).ToInt64() + if err != nil { + return 0 + } + return int(res) } func (r *RedisGroupedStringCache) Contains(searchString string, groups []string) []string { - cmds, err := r.rdb.Pipelined(context.Background(), func(pipeline redis.Pipeliner) error { - for _, group := range groups { - pipeline.SIsMember(context.Background(), r.cacheKey(group), searchString) - } - - return nil - }) - if err != nil { - panic(err) + var cmds []rueidis.CacheableTTL + for _, group := range groups { + cmds = append(cmds, rueidis.CT(r.rdb.B().Sismember().Key(r.cacheKey(group)).Member(searchString).Cache(), time.Second)) } + resps := r.rdb.DoMultiCache(context.Background(), cmds...) var result []string for ix, group := range groups { - if cmds[ix].(*redis.BoolCmd).Val() { + r, err := resps[ix].AsBool() + if err != nil { + panic(err) + } + if r { result = append(result, group) } } @@ -55,30 +55,26 @@ func (r *RedisGroupedStringCache) Contains(searchString string, groups []string) } func (r *RedisGroupedStringCache) Refresh(group string) GroupFactory { - pipeline := r.rdb.Pipeline() - pipeline.Del(context.Background(), r.cacheKey(group)) + cmds := rueidis.Commands{r.rdb.B().Del().Key(r.cacheKey(group)).Build()} f := &RedisGroupFactory{ - rdb: r.rdb, - name: r.cacheKey(group), - pipeline: pipeline, + rdb: r.rdb, + name: r.cacheKey(group), + cmds: cmds, } return f } type RedisGroupFactory struct { - rdb *redis.Client - name string - pipeline redis.Pipeliner - cnt int + rdb rueidis.Client + name string + cmds rueidis.Commands + cnt int } func (r *RedisGroupFactory) AddEntry(entry string) { - err := r.pipeline.SAdd(context.Background(), r.name, entry).Err() - if err != nil { - panic(err) - } + r.cmds = append(r.cmds, r.rdb.B().Sadd().Key(r.name).Member(entry).Build()) r.cnt++ } @@ -88,8 +84,5 @@ func (r *RedisGroupFactory) Count() int { } func (r *RedisGroupFactory) Finish() { - _, err := r.pipeline.Exec(context.Background()) - if err != nil { - panic(err) - } + _ = r.rdb.DoMulti(context.Background(), r.cmds...) } diff --git a/config/blocking_test.go b/config/blocking_test.go index b806b028..59258525 100644 --- a/config/blocking_test.go +++ b/config/blocking_test.go @@ -10,9 +10,7 @@ import ( ) var _ = Describe("BlockingConfig", func() { - var ( - cfg BlockingConfig - ) + var cfg BlockingConfig suiteBeforeEach() diff --git a/config/config.go b/config/config.go index 455757ed..0ed637b6 100644 --- a/config/config.go +++ b/config/config.go @@ -232,16 +232,16 @@ type ( // RedisConfig configuration for the redis connection type RedisConfig struct { - Address string `yaml:"address"` - Username string `yaml:"username" default:""` - Password string `yaml:"password" default:""` - Database int `yaml:"database" default:"0"` - Required bool `yaml:"required" default:"false"` - ConnectionAttempts int `yaml:"connectionAttempts" default:"3"` - ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"` - SentinelUsername string `yaml:"sentinelUsername" default:""` - SentinelPassword string `yaml:"sentinelPassword" default:""` - SentinelAddresses []string `yaml:"sentinelAddresses"` + Addresses []string `yaml:"addresses"` + Username string `yaml:"username" default:""` + Password string `yaml:"password" default:""` + Database int `yaml:"database" default:"0"` + SentinelUsername string `yaml:"sentinelUsername" default:""` + SentinelPassword string `yaml:"sentinelPassword" default:""` + SentinelMasterSet string `yaml:"sentinelMasterSet" default:""` + ClientMaxCachingTime Duration `yaml:"clientMaxCachingTime" default:"1h"` + ConnectionAttempts int `yaml:"connectionAttempts" default:"3"` + ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"` } type ( diff --git a/go.mod b/go.mod index 62bf6014..3eebeb7a 100644 --- a/go.mod +++ b/go.mod @@ -55,7 +55,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/go-logr/logr v1.2.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/jackc/pgx/v5 v5.3.0 // indirect github.com/klauspost/compress v1.11.13 // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -66,6 +66,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0-rc2 // indirect github.com/opencontainers/runc v1.1.3 // indirect + github.com/rueian/rueidis v0.0.95 // indirect golang.org/x/sync v0.1.0 // indirect golang.org/x/tools/cmd/cover v0.1.0-deprecated // indirect google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad // indirect diff --git a/go.sum b/go.sum index c72538d3..f7e085f8 100644 --- a/go.sum +++ b/go.sum @@ -229,6 +229,8 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= +github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -396,6 +398,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rueian/rueidis v0.0.95 h1:zrqNGlIH+YYokKxCCSEnOCIrjFVFb29IqNurQEVqXMc= +github.com/rueian/rueidis v0.0.95/go.mod h1:Ziv7p67TsXd3zAcQ5+BDdQvOsUE9N68u6JcA7/rE+x8= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= diff --git a/lists/list_cache.go b/lists/list_cache.go index e5dedcf1..73a76920 100644 --- a/lists/list_cache.go +++ b/lists/list_cache.go @@ -11,6 +11,10 @@ import ( "strings" "time" + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/redis" + "github.com/0xERR0R/blocky/util" + "github.com/hashicorp/go-multierror" "github.com/sirupsen/logrus" @@ -70,11 +74,23 @@ func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeri processingConcurrency = defaultProcessingConcurrency } - b := &ListCache{ - groupedCache: stringcache.NewChainedGroupedCache( + rdb, err := redis.NewRedisClient(&config.GetConfig().Redis) + util.FatalOnError("can't create redis client", err) + var groupedCache stringcache.GroupedStringCache + if rdb != nil { + // redis + groupedCache = stringcache.NewChainedGroupedCache( + stringcache.NewRedisGroupedStringCache(t.String(), rdb), + stringcache.NewInMemoryGroupedRegexCache()) + } else { + // in-memory + groupedCache = stringcache.NewChainedGroupedCache( stringcache.NewInMemoryGroupedStringCache(), - stringcache.NewInMemoryGroupedRegexCache(), - ), + stringcache.NewInMemoryGroupedRegexCache()) + } + + b := &ListCache{ + groupedCache: groupedCache, groupToLinks: groupToLinks, refreshPeriod: refreshPeriod, downloader: downloader, diff --git a/redis/redis.go b/redis/redis.go index adf22f1e..291dc8db 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -9,10 +9,8 @@ import ( "time" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/go-redis/redis/v8" - "github.com/google/uuid" "github.com/miekg/dns" "github.com/sirupsen/logrus" ) @@ -65,67 +63,6 @@ type Client struct { EnabledChannel chan *EnabledMessage } -// New creates a new redis client -func New(cfg *config.RedisConfig) (*Client, error) { - // disable redis if no address is provided - if cfg == nil || len(cfg.Address) == 0 { - return nil, nil //nolint:nilnil - } - - var rdb *redis.Client - if len(cfg.SentinelAddresses) > 0 { - rdb = redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: cfg.Address, - SentinelUsername: cfg.Username, - SentinelPassword: cfg.SentinelPassword, - SentinelAddrs: cfg.SentinelAddresses, - Username: cfg.Username, - Password: cfg.Password, - DB: cfg.Database, - MaxRetries: cfg.ConnectionAttempts, - MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(), - }) - } else { - rdb = redis.NewClient(&redis.Options{ - Addr: cfg.Address, - Username: cfg.Username, - Password: cfg.Password, - DB: cfg.Database, - MaxRetries: cfg.ConnectionAttempts, - MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(), - }) - } - - ctx := context.Background() - - _, err := rdb.Ping(ctx).Result() - if err == nil { - var id []byte - - id, err = uuid.New().MarshalBinary() - if err == nil { - // construct client - res := &Client{ - config: cfg, - client: rdb, - l: log.PrefixedLog("redis"), - ctx: ctx, - id: id, - sendBuffer: make(chan *bufferMessage, chanCap), - CacheChannel: make(chan *CacheMessage, chanCap), - EnabledChannel: make(chan *EnabledMessage, chanCap), - } - - // start channel handling go routine - err = res.startup() - - return res, err - } - } - - return nil, err -} - // PublishCache publish cache to redis async func (c *Client) PublishCache(key string, message *dns.Msg) { if len(key) > 0 && message != nil { diff --git a/redis/redis_test.go b/redis/redis_test.go deleted file mode 100644 index 9ddf6185..00000000 --- a/redis/redis_test.go +++ /dev/null @@ -1,312 +0,0 @@ -package redis - -import ( - "encoding/json" - "time" - - "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/util" - "github.com/alicebob/miniredis/v2" - "github.com/creasty/defaults" - "github.com/google/uuid" - "github.com/miekg/dns" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var ( - redisServer *miniredis.Miniredis - redisClient *Client - redisConfig *config.RedisConfig - err error -) - -var _ = Describe("Redis client", func() { - BeforeEach(func() { - redisServer, err = miniredis.Run() - - Expect(err).Should(Succeed()) - - DeferCleanup(redisServer.Close) - - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - - Expect(err).Should(Succeed()) - - rcfg.Address = redisServer.Addr() - redisConfig = &rcfg - redisClient, err = New(redisConfig) - - Expect(err).Should(Succeed()) - Expect(redisClient).ShouldNot(BeNil()) - }) - Describe("Client creation", func() { - When("redis configuration has no address", func() { - It("should return nil without error", func() { - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - - Expect(err).Should(Succeed()) - - Expect(New(&rcfg)).Should(BeNil()) - }) - }) - When("redis configuration has invalid address", func() { - It("should fail with error", func() { - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - Expect(err).Should(Succeed()) - - rcfg.Address = "127.0.0.1:0" - - _, err = New(&rcfg) - - Expect(err).Should(HaveOccurred()) - }) - }) - When("redis configuration has invalid password", func() { - It("should fail with error", func() { - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - Expect(err).Should(Succeed()) - - rcfg.Address = redisServer.Addr() - rcfg.Password = "wrong" - - _, err = New(&rcfg) - - Expect(err).Should(HaveOccurred()) - }) - }) - }) - - Describe("Publish message", func() { - When("Redis client publishes 'cache' message", func() { - It("One new entry with TTL > 0 should be persisted in the database", func() { - By("Database is empty", func() { - Eventually(func() []string { - return redisServer.DB(redisConfig.Database).Keys() - }).Should(BeEmpty()) - }) - - By("publish new message with TTL > 0", func() { - res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") - - Expect(err).Should(Succeed()) - - redisClient.PublishCache("example.com", res) - }) - - By("Database has one entry with correct TTL", func() { - Eventually(func() bool { - return redisServer.DB(redisConfig.Database).Exists(CacheStorePrefix + "example.com") - }).Should(BeTrue()) - - ttl := redisServer.DB(redisConfig.Database).TTL(CacheStorePrefix + "example.com") - Expect(ttl.Seconds()).Should(BeNumerically("~", 123)) - }) - }) - - It("One new entry with default TTL should be persisted in the database", func() { - By("Database is empty", func() { - Eventually(func() []string { - return redisServer.DB(redisConfig.Database).Keys() - }).Should(BeEmpty()) - }) - - By("publish new message with TTL = 0", func() { - res, err := util.NewMsgWithAnswer("example.com.", 0, dns.Type(dns.TypeA), "123.124.122.123") - - Expect(err).Should(Succeed()) - - redisClient.PublishCache("example.com", res) - }) - - By("Database has one entry with default TTL", func() { - Eventually(func() bool { - return redisServer.DB(redisConfig.Database).Exists(CacheStorePrefix + "example.com") - }).Should(BeTrue()) - - ttl := redisServer.DB(redisConfig.Database).TTL(CacheStorePrefix + "example.com") - Expect(ttl.Seconds()).Should(BeNumerically("~", defaultCacheTime.Seconds())) - }) - }) - }) - When("Redis client publishes 'enabled' message", func() { - It("should propagate the message over redis", func() { - redisClient.PublishEnabled(&EnabledMessage{ - State: true, - }) - Eventually(func() map[string]int { - return redisServer.PubSubNumSub(SyncChannelName) - }).Should(HaveLen(1)) - }) - }) - }) - - Describe("Receive message", func() { - When("'enabled' message is received", func() { - It("should propagate the message over the channel", func() { - var binState []byte - binState, err = json.Marshal(EnabledMessage{State: true}) - Expect(err).Should(Succeed()) - - var id []byte - id, err = uuid.New().MarshalBinary() - Expect(err).Should(Succeed()) - - var binMsg []byte - binMsg, err = json.Marshal(redisMessage{ - Type: messageTypeEnable, - Message: binState, - Client: id, - }) - Expect(err).Should(Succeed()) - - lenE := len(redisClient.EnabledChannel) - - rec := redisServer.Publish(SyncChannelName, string(binMsg)) - Expect(rec).Should(Equal(1)) - - Eventually(func() chan *EnabledMessage { - return redisClient.EnabledChannel - }).Should(HaveLen(lenE + 1)) - }) - }) - When("'cache' message is received", func() { - It("should propagate the message over the channel", func() { - res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") - - Expect(err).Should(Succeed()) - - var binState []byte - binState, err = res.Pack() - Expect(err).Should(Succeed()) - - var id []byte - id, err = uuid.New().MarshalBinary() - Expect(err).Should(Succeed()) - - var binMsg []byte - binMsg, err = json.Marshal(redisMessage{ - Key: "example.com", - Type: messageTypeCache, - Message: binState, - Client: id, - }) - Expect(err).Should(Succeed()) - - lenE := len(redisClient.CacheChannel) - - rec := redisServer.Publish(SyncChannelName, string(binMsg)) - Expect(rec).Should(Equal(1)) - - Eventually(func() chan *CacheMessage { - return redisClient.CacheChannel - }).Should(HaveLen(lenE + 1)) - }) - }) - When("wrong data is received", func() { - It("should not propagate the message over the channel if data is wrong", func() { - var id []byte - id, err = uuid.New().MarshalBinary() - Expect(err).Should(Succeed()) - - var binMsg []byte - binMsg, err = json.Marshal(redisMessage{ - Key: "unknown", - Type: messageTypeCache, - Message: []byte("test"), - Client: id, - }) - Expect(err).Should(Succeed()) - - lenE := len(redisClient.EnabledChannel) - lenC := len(redisClient.CacheChannel) - - rec := redisServer.Publish(SyncChannelName, string(binMsg)) - Expect(rec).Should(Equal(1)) - - Eventually(func() chan *EnabledMessage { - return redisClient.EnabledChannel - }).Should(HaveLen(lenE)) - - Eventually(func() chan *CacheMessage { - return redisClient.CacheChannel - }).Should(HaveLen(lenC)) - }) - It("should not propagate the message over the channel if type is wrong", func() { - var id []byte - id, err = uuid.New().MarshalBinary() - Expect(err).Should(Succeed()) - - var binMsg []byte - binMsg, err = json.Marshal(redisMessage{ - Key: "unknown", - Type: 99, - Message: []byte("test"), - Client: id, - }) - Expect(err).Should(Succeed()) - - lenE := len(redisClient.EnabledChannel) - lenC := len(redisClient.CacheChannel) - - rec := redisServer.Publish(SyncChannelName, string(binMsg)) - Expect(rec).Should(Equal(1)) - - time.Sleep(2 * time.Second) - - Eventually(func() chan *EnabledMessage { - return redisClient.EnabledChannel - }).Should(HaveLen(lenE)) - - Eventually(func() chan *CacheMessage { - return redisClient.CacheChannel - }).Should(HaveLen(lenC)) - }) - }) - }) - - Describe("Read the redis cache and publish it to the channel", func() { - When("GetRedisCache is called with valid database entries", func() { - It("Should read data from Redis and propagate it via cache channel", func() { - By("Database is empty", func() { - Eventually(func() []string { - return redisServer.DB(redisConfig.Database).Keys() - }).Should(BeEmpty()) - }) - - By("Put valid data in Redis by publishing the cache entry", func() { - var res *dns.Msg - - res, err = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") - - Expect(err).Should(Succeed()) - - redisClient.PublishCache("example.com", res) - }) - - By("Database has one entry now", func() { - Eventually(func() []string { - return redisServer.DB(redisConfig.Database).Keys() - }).Should(HaveLen(1)) - }) - - By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() { - redisClient.GetRedisCache() - - Eventually(redisClient.CacheChannel).Should(HaveLen(1)) - }) - }) - }) - When("GetRedisCache is called and database contains not valid entry", func() { - It("Should do nothing (only log error)", func() { - Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed()) - redisClient.GetRedisCache() - Consistently(redisClient.CacheChannel).Should(BeEmpty()) - }) - }) - }) -}) diff --git a/redis/rueidis.go b/redis/rueidis.go new file mode 100644 index 00000000..973bba5e --- /dev/null +++ b/redis/rueidis.go @@ -0,0 +1,50 @@ +package redis + +import ( + "fmt" + "time" + + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/util" + "github.com/rueian/rueidis" +) + +// New creates a new redis client +func NewRedisClient(cfg *config.RedisConfig) (rueidis.Client, error) { + // disable redis if no address is provided + if cfg == nil || len(cfg.Addresses) == 0 { + return nil, nil //nolint:nilnil + } + + roption := rueidis.ClientOption{ + InitAddress: cfg.Addresses, + Password: cfg.Password, + Username: cfg.Username, + SelectDB: cfg.Database, + ClientName: fmt.Sprintf("blocky-%s", util.HostnameString()), + ClientTrackingOptions: []string{"PREFIX", "blocky:", "BCAST"}, + } + + if len(cfg.SentinelMasterSet) > 0 { + roption.Sentinel = rueidis.SentinelOption{ + Username: cfg.SentinelUsername, + Password: cfg.SentinelPassword, + MasterSet: cfg.SentinelMasterSet, + } + } + + var err error + var client rueidis.Client + for i := 0; i < cfg.ConnectionAttempts; i++ { + client, err = rueidis.NewClient(roption) + if err == nil { + break + } + time.Sleep(time.Duration(cfg.ConnectionCooldown)) + } + if err != nil { + return nil, err + } + + return client, nil +} diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 08091b9d..1b5f3438 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -28,7 +28,7 @@ type CachingResolver struct { emitMetricEvents bool // disabled by Bootstrap - resultCache expirationcache.ExpiringCache[cacheValue] + resultCache expirationcache.ExpiringCache[dns.Msg] prefetchingNameCache expirationcache.ExpiringCache[int] redisClient *redis.Client } @@ -55,30 +55,20 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri configureCaches(c, &cfg) - if c.redisClient != nil { - setupRedisCacheSubscriber(c) - c.redisClient.GetRedisCache() - } - return c } func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { - cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval) - maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount)) + rdb, err := redis.NewRedisClient(&config.GetConfig().Redis) + util.FatalOnError("can't create redis client", err) - if cfg.Prefetching { - c.prefetchingNameCache = expirationcache.NewCache( - expirationcache.WithCleanUpInterval[int](time.Minute), - expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)), - ) - - c.resultCache = expirationcache.NewCache( - cleanupOption, - maxSizeOption, - expirationcache.WithOnExpiredFn(c.onExpired), - ) + if rdb != nil { + // redis + c.resultCache = expirationcache.NewRedisCache(rdb, "query") } else { + // in-memory + cleanupOption := expirationcache.WithCleanUpInterval[dns.Msg](defaultCachingCleanUpInterval) + maxSizeOption := expirationcache.WithMaxSize[dns.Msg](uint(cfg.MaxItemsCount)) c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption) } } @@ -162,14 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain) - if val.prefetch { - // Hit from prefetch cache - r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain) - } + //if val.prefetch { + // // Hit from prefetch cache + // r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain) + //} - resp := val.resultMsg.Copy() + resp := val.Copy() resp.SetReply(request.Req) - resp.Rcode = val.resultMsg.Rcode + resp.Rcode = val.Rcode // Adjust TTL for _, rr := range resp.Answer { @@ -215,21 +205,15 @@ func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, log func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, prefetch, publish bool) { if response.Res.Rcode == dns.RcodeSuccess { // put value into cache - r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer)) + r.resultCache.Put(cacheKey, response.Res, r.adjustTTLs(response.Res.Answer)) } else if response.Res.Rcode == dns.RcodeNameError { if r.cfg.CacheTimeNegative > 0 { // put negative cache if result code is NXDOMAIN - r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration()) + r.resultCache.Put(cacheKey, response.Res, r.cfg.CacheTimeNegative.ToDuration()) } } r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount()) - - if publish && r.redisClient != nil { - res := *response.Res - res.Answer = response.Res.Answer - r.redisClient.PublishCache(cacheKey, &res) - } } // adjustTTLs calculates and returns the max TTL (considers also the min and max cache time) diff --git a/server/server.go b/server/server.go index 2d36101b..74de0357 100644 --- a/server/server.go +++ b/server/server.go @@ -151,12 +151,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) { return nil, err } - redisClient, redisErr := redis.New(&cfg.Redis) - if redisErr != nil && cfg.Redis.Required { - return nil, redisErr - } - - queryResolver, queryError := createQueryResolver(cfg, bootstrap, redisClient) + queryResolver, queryError := createQueryResolver(cfg, bootstrap, nil) if queryError != nil { return nil, queryError } diff --git a/util/common.go b/util/common.go index b95ae05b..844581a5 100644 --- a/util/common.go +++ b/util/common.go @@ -165,13 +165,7 @@ func FatalOnError(message string, err error) { // GenerateCacheKey return cacheKey by query type/domain func GenerateCacheKey(qType dns.Type, qName string) string { - const qTypeLength = 2 - b := make([]byte, qTypeLength+len(qName)) - - binary.BigEndian.PutUint16(b, uint16(qType)) - copy(b[2:], strings.ToLower(qName)) - - return string(b) + return dns.TypeToString[uint16(qType)] + ":" + qName } // ExtractCacheKey return query type/domain from cacheKey