mirror of https://github.com/0xERR0R/blocky.git
WIP: switch to rueidis
This commit is contained in:
parent
a05441f452
commit
e106719fcb
|
@ -0,0 +1,89 @@
|
|||
package expirationcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rueian/rueidis"
|
||||
)
|
||||
|
||||
type RedisCache struct {
|
||||
rdb rueidis.Client
|
||||
name string
|
||||
}
|
||||
|
||||
func Key(k ...string) string {
|
||||
return fmt.Sprintf("blocky:%s", strings.Join(k, ":"))
|
||||
}
|
||||
|
||||
func (r *RedisCache) cacheKey(key string) string {
|
||||
return Key("cache", r.name, key)
|
||||
}
|
||||
|
||||
func toSeconds(t time.Duration) int64 {
|
||||
return int64(t.Seconds())
|
||||
}
|
||||
|
||||
func NoResult(res rueidis.RedisResult) bool {
|
||||
err := res.Error()
|
||||
|
||||
return err != nil && rueidis.IsRedisNil(err)
|
||||
}
|
||||
|
||||
func (r *RedisCache) Put(key string, val *dns.Msg, expiration time.Duration) {
|
||||
b, err := val.Pack()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cmd := r.rdb.B().Setex().Key(r.cacheKey(key)).Seconds(toSeconds(expiration)).Value(rueidis.BinaryString(b)).Build()
|
||||
r.rdb.Do(context.Background(), cmd).Error()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisCache) Get(key string) (val *dns.Msg, expiration time.Duration) {
|
||||
cmd := r.rdb.B().Get().Key(r.cacheKey(key)).Cache()
|
||||
resp := r.rdb.DoCache(context.Background(), cmd, 600*time.Second)
|
||||
if NoResult(resp) {
|
||||
return nil, 0
|
||||
}
|
||||
err := resp.Error()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
respStr, err := resp.ToString()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bytesVal := []byte(respStr)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
err = msg.Unpack(bytesVal)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return msg, time.Duration(resp.CacheTTL() * int64(time.Second))
|
||||
}
|
||||
|
||||
func (r *RedisCache) TotalCount() int {
|
||||
// TODO implement me
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *RedisCache) Clear() {
|
||||
// TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func NewRedisCache(rdb rueidis.Client, name string) ExpiringCache[dns.Msg] {
|
||||
return &RedisCache{
|
||||
rdb: rdb,
|
||||
name: name,
|
||||
}
|
||||
}
|
|
@ -3,23 +3,20 @@ package stringcache
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/rueian/rueidis"
|
||||
)
|
||||
|
||||
type RedisGroupedStringCache struct {
|
||||
rdb *redis.Client
|
||||
rdb rueidis.Client
|
||||
name string
|
||||
}
|
||||
|
||||
func NewRedisGroupedStringCache(name string) *RedisGroupedStringCache {
|
||||
func NewRedisGroupedStringCache(name string, rdb rueidis.Client) *RedisGroupedStringCache {
|
||||
return &RedisGroupedStringCache{
|
||||
name: name,
|
||||
rdb: redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
}),
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -28,25 +25,28 @@ func (r *RedisGroupedStringCache) cacheKey(groupName string) string {
|
|||
}
|
||||
|
||||
func (r *RedisGroupedStringCache) ElementCount(group string) int {
|
||||
return int(r.rdb.SCard(context.Background(), r.cacheKey(group)).Val())
|
||||
res, err := r.rdb.DoCache(context.Background(), r.rdb.B().Scard().Key(r.cacheKey(group)).Cache(), 600*time.Second).ToInt64()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return int(res)
|
||||
}
|
||||
|
||||
func (r *RedisGroupedStringCache) Contains(searchString string, groups []string) []string {
|
||||
cmds, err := r.rdb.Pipelined(context.Background(), func(pipeline redis.Pipeliner) error {
|
||||
for _, group := range groups {
|
||||
pipeline.SIsMember(context.Background(), r.cacheKey(group), searchString)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
var cmds []rueidis.CacheableTTL
|
||||
for _, group := range groups {
|
||||
cmds = append(cmds, rueidis.CT(r.rdb.B().Sismember().Key(r.cacheKey(group)).Member(searchString).Cache(), time.Second))
|
||||
}
|
||||
resps := r.rdb.DoMultiCache(context.Background(), cmds...)
|
||||
|
||||
var result []string
|
||||
|
||||
for ix, group := range groups {
|
||||
if cmds[ix].(*redis.BoolCmd).Val() {
|
||||
r, err := resps[ix].AsBool()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if r {
|
||||
result = append(result, group)
|
||||
}
|
||||
}
|
||||
|
@ -55,30 +55,26 @@ func (r *RedisGroupedStringCache) Contains(searchString string, groups []string)
|
|||
}
|
||||
|
||||
func (r *RedisGroupedStringCache) Refresh(group string) GroupFactory {
|
||||
pipeline := r.rdb.Pipeline()
|
||||
pipeline.Del(context.Background(), r.cacheKey(group))
|
||||
cmds := rueidis.Commands{r.rdb.B().Del().Key(r.cacheKey(group)).Build()}
|
||||
|
||||
f := &RedisGroupFactory{
|
||||
rdb: r.rdb,
|
||||
name: r.cacheKey(group),
|
||||
pipeline: pipeline,
|
||||
rdb: r.rdb,
|
||||
name: r.cacheKey(group),
|
||||
cmds: cmds,
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
type RedisGroupFactory struct {
|
||||
rdb *redis.Client
|
||||
name string
|
||||
pipeline redis.Pipeliner
|
||||
cnt int
|
||||
rdb rueidis.Client
|
||||
name string
|
||||
cmds rueidis.Commands
|
||||
cnt int
|
||||
}
|
||||
|
||||
func (r *RedisGroupFactory) AddEntry(entry string) {
|
||||
err := r.pipeline.SAdd(context.Background(), r.name, entry).Err()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
r.cmds = append(r.cmds, r.rdb.B().Sadd().Key(r.name).Member(entry).Build())
|
||||
|
||||
r.cnt++
|
||||
}
|
||||
|
@ -88,8 +84,5 @@ func (r *RedisGroupFactory) Count() int {
|
|||
}
|
||||
|
||||
func (r *RedisGroupFactory) Finish() {
|
||||
_, err := r.pipeline.Exec(context.Background())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_ = r.rdb.DoMulti(context.Background(), r.cmds...)
|
||||
}
|
||||
|
|
|
@ -10,9 +10,7 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("BlockingConfig", func() {
|
||||
var (
|
||||
cfg BlockingConfig
|
||||
)
|
||||
var cfg BlockingConfig
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
|
|
|
@ -232,16 +232,16 @@ type (
|
|||
|
||||
// RedisConfig configuration for the redis connection
|
||||
type RedisConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
Username string `yaml:"username" default:""`
|
||||
Password string `yaml:"password" default:""`
|
||||
Database int `yaml:"database" default:"0"`
|
||||
Required bool `yaml:"required" default:"false"`
|
||||
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
|
||||
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
|
||||
SentinelUsername string `yaml:"sentinelUsername" default:""`
|
||||
SentinelPassword string `yaml:"sentinelPassword" default:""`
|
||||
SentinelAddresses []string `yaml:"sentinelAddresses"`
|
||||
Addresses []string `yaml:"addresses"`
|
||||
Username string `yaml:"username" default:""`
|
||||
Password string `yaml:"password" default:""`
|
||||
Database int `yaml:"database" default:"0"`
|
||||
SentinelUsername string `yaml:"sentinelUsername" default:""`
|
||||
SentinelPassword string `yaml:"sentinelPassword" default:""`
|
||||
SentinelMasterSet string `yaml:"sentinelMasterSet" default:""`
|
||||
ClientMaxCachingTime Duration `yaml:"clientMaxCachingTime" default:"1h"`
|
||||
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
|
||||
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
|
||||
}
|
||||
|
||||
type (
|
||||
|
|
3
go.mod
3
go.mod
|
@ -55,7 +55,7 @@ require (
|
|||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/go-logr/logr v1.2.3 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
|
||||
github.com/jackc/pgx/v5 v5.3.0 // indirect
|
||||
github.com/klauspost/compress v1.11.13 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
|
@ -66,6 +66,7 @@ require (
|
|||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0-rc2 // indirect
|
||||
github.com/opencontainers/runc v1.1.3 // indirect
|
||||
github.com/rueian/rueidis v0.0.95 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/tools/cmd/cover v0.1.0-deprecated // indirect
|
||||
google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -229,6 +229,8 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf
|
|||
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
|
@ -396,6 +398,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L
|
|||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rueian/rueidis v0.0.95 h1:zrqNGlIH+YYokKxCCSEnOCIrjFVFb29IqNurQEVqXMc=
|
||||
github.com/rueian/rueidis v0.0.95/go.mod h1:Ziv7p67TsXd3zAcQ5+BDdQvOsUE9N68u6JcA7/rE+x8=
|
||||
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
|
|
|
@ -11,6 +11,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/redis"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
|
@ -70,11 +74,23 @@ func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeri
|
|||
processingConcurrency = defaultProcessingConcurrency
|
||||
}
|
||||
|
||||
b := &ListCache{
|
||||
groupedCache: stringcache.NewChainedGroupedCache(
|
||||
rdb, err := redis.NewRedisClient(&config.GetConfig().Redis)
|
||||
util.FatalOnError("can't create redis client", err)
|
||||
var groupedCache stringcache.GroupedStringCache
|
||||
if rdb != nil {
|
||||
// redis
|
||||
groupedCache = stringcache.NewChainedGroupedCache(
|
||||
stringcache.NewRedisGroupedStringCache(t.String(), rdb),
|
||||
stringcache.NewInMemoryGroupedRegexCache())
|
||||
} else {
|
||||
// in-memory
|
||||
groupedCache = stringcache.NewChainedGroupedCache(
|
||||
stringcache.NewInMemoryGroupedStringCache(),
|
||||
stringcache.NewInMemoryGroupedRegexCache(),
|
||||
),
|
||||
stringcache.NewInMemoryGroupedRegexCache())
|
||||
}
|
||||
|
||||
b := &ListCache{
|
||||
groupedCache: groupedCache,
|
||||
groupToLinks: groupToLinks,
|
||||
refreshPeriod: refreshPeriod,
|
||||
downloader: downloader,
|
||||
|
|
|
@ -9,10 +9,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/google/uuid"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -65,67 +63,6 @@ type Client struct {
|
|||
EnabledChannel chan *EnabledMessage
|
||||
}
|
||||
|
||||
// New creates a new redis client
|
||||
func New(cfg *config.RedisConfig) (*Client, error) {
|
||||
// disable redis if no address is provided
|
||||
if cfg == nil || len(cfg.Address) == 0 {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
var rdb *redis.Client
|
||||
if len(cfg.SentinelAddresses) > 0 {
|
||||
rdb = redis.NewFailoverClient(&redis.FailoverOptions{
|
||||
MasterName: cfg.Address,
|
||||
SentinelUsername: cfg.Username,
|
||||
SentinelPassword: cfg.SentinelPassword,
|
||||
SentinelAddrs: cfg.SentinelAddresses,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.Database,
|
||||
MaxRetries: cfg.ConnectionAttempts,
|
||||
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
||||
})
|
||||
} else {
|
||||
rdb = redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Address,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.Database,
|
||||
MaxRetries: cfg.ConnectionAttempts,
|
||||
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := rdb.Ping(ctx).Result()
|
||||
if err == nil {
|
||||
var id []byte
|
||||
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
if err == nil {
|
||||
// construct client
|
||||
res := &Client{
|
||||
config: cfg,
|
||||
client: rdb,
|
||||
l: log.PrefixedLog("redis"),
|
||||
ctx: ctx,
|
||||
id: id,
|
||||
sendBuffer: make(chan *bufferMessage, chanCap),
|
||||
CacheChannel: make(chan *CacheMessage, chanCap),
|
||||
EnabledChannel: make(chan *EnabledMessage, chanCap),
|
||||
}
|
||||
|
||||
// start channel handling go routine
|
||||
err = res.startup()
|
||||
|
||||
return res, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// PublishCache publish cache to redis async
|
||||
func (c *Client) PublishCache(key string, message *dns.Msg) {
|
||||
if len(key) > 0 && message != nil {
|
||||
|
|
|
@ -1,312 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/creasty/defaults"
|
||||
"github.com/google/uuid"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var (
|
||||
redisServer *miniredis.Miniredis
|
||||
redisClient *Client
|
||||
redisConfig *config.RedisConfig
|
||||
err error
|
||||
)
|
||||
|
||||
var _ = Describe("Redis client", func() {
|
||||
BeforeEach(func() {
|
||||
redisServer, err = miniredis.Run()
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
DeferCleanup(redisServer.Close)
|
||||
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
rcfg.Address = redisServer.Addr()
|
||||
redisConfig = &rcfg
|
||||
redisClient, err = New(redisConfig)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(redisClient).ShouldNot(BeNil())
|
||||
})
|
||||
Describe("Client creation", func() {
|
||||
When("redis configuration has no address", func() {
|
||||
It("should return nil without error", func() {
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Expect(New(&rcfg)).Should(BeNil())
|
||||
})
|
||||
})
|
||||
When("redis configuration has invalid address", func() {
|
||||
It("should fail with error", func() {
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
rcfg.Address = "127.0.0.1:0"
|
||||
|
||||
_, err = New(&rcfg)
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
When("redis configuration has invalid password", func() {
|
||||
It("should fail with error", func() {
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
rcfg.Address = redisServer.Addr()
|
||||
rcfg.Password = "wrong"
|
||||
|
||||
_, err = New(&rcfg)
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Publish message", func() {
|
||||
When("Redis client publishes 'cache' message", func() {
|
||||
It("One new entry with TTL > 0 should be persisted in the database", func() {
|
||||
By("Database is empty", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
}).Should(BeEmpty())
|
||||
})
|
||||
|
||||
By("publish new message with TTL > 0", func() {
|
||||
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", res)
|
||||
})
|
||||
|
||||
By("Database has one entry with correct TTL", func() {
|
||||
Eventually(func() bool {
|
||||
return redisServer.DB(redisConfig.Database).Exists(CacheStorePrefix + "example.com")
|
||||
}).Should(BeTrue())
|
||||
|
||||
ttl := redisServer.DB(redisConfig.Database).TTL(CacheStorePrefix + "example.com")
|
||||
Expect(ttl.Seconds()).Should(BeNumerically("~", 123))
|
||||
})
|
||||
})
|
||||
|
||||
It("One new entry with default TTL should be persisted in the database", func() {
|
||||
By("Database is empty", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
}).Should(BeEmpty())
|
||||
})
|
||||
|
||||
By("publish new message with TTL = 0", func() {
|
||||
res, err := util.NewMsgWithAnswer("example.com.", 0, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", res)
|
||||
})
|
||||
|
||||
By("Database has one entry with default TTL", func() {
|
||||
Eventually(func() bool {
|
||||
return redisServer.DB(redisConfig.Database).Exists(CacheStorePrefix + "example.com")
|
||||
}).Should(BeTrue())
|
||||
|
||||
ttl := redisServer.DB(redisConfig.Database).TTL(CacheStorePrefix + "example.com")
|
||||
Expect(ttl.Seconds()).Should(BeNumerically("~", defaultCacheTime.Seconds()))
|
||||
})
|
||||
})
|
||||
})
|
||||
When("Redis client publishes 'enabled' message", func() {
|
||||
It("should propagate the message over redis", func() {
|
||||
redisClient.PublishEnabled(&EnabledMessage{
|
||||
State: true,
|
||||
})
|
||||
Eventually(func() map[string]int {
|
||||
return redisServer.PubSubNumSub(SyncChannelName)
|
||||
}).Should(HaveLen(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Receive message", func() {
|
||||
When("'enabled' message is received", func() {
|
||||
It("should propagate the message over the channel", func() {
|
||||
var binState []byte
|
||||
binState, err = json.Marshal(EnabledMessage{State: true})
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var id []byte
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var binMsg []byte
|
||||
binMsg, err = json.Marshal(redisMessage{
|
||||
Type: messageTypeEnable,
|
||||
Message: binState,
|
||||
Client: id,
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
lenE := len(redisClient.EnabledChannel)
|
||||
|
||||
rec := redisServer.Publish(SyncChannelName, string(binMsg))
|
||||
Expect(rec).Should(Equal(1))
|
||||
|
||||
Eventually(func() chan *EnabledMessage {
|
||||
return redisClient.EnabledChannel
|
||||
}).Should(HaveLen(lenE + 1))
|
||||
})
|
||||
})
|
||||
When("'cache' message is received", func() {
|
||||
It("should propagate the message over the channel", func() {
|
||||
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var binState []byte
|
||||
binState, err = res.Pack()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var id []byte
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var binMsg []byte
|
||||
binMsg, err = json.Marshal(redisMessage{
|
||||
Key: "example.com",
|
||||
Type: messageTypeCache,
|
||||
Message: binState,
|
||||
Client: id,
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
lenE := len(redisClient.CacheChannel)
|
||||
|
||||
rec := redisServer.Publish(SyncChannelName, string(binMsg))
|
||||
Expect(rec).Should(Equal(1))
|
||||
|
||||
Eventually(func() chan *CacheMessage {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenE + 1))
|
||||
})
|
||||
})
|
||||
When("wrong data is received", func() {
|
||||
It("should not propagate the message over the channel if data is wrong", func() {
|
||||
var id []byte
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var binMsg []byte
|
||||
binMsg, err = json.Marshal(redisMessage{
|
||||
Key: "unknown",
|
||||
Type: messageTypeCache,
|
||||
Message: []byte("test"),
|
||||
Client: id,
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
lenE := len(redisClient.EnabledChannel)
|
||||
lenC := len(redisClient.CacheChannel)
|
||||
|
||||
rec := redisServer.Publish(SyncChannelName, string(binMsg))
|
||||
Expect(rec).Should(Equal(1))
|
||||
|
||||
Eventually(func() chan *EnabledMessage {
|
||||
return redisClient.EnabledChannel
|
||||
}).Should(HaveLen(lenE))
|
||||
|
||||
Eventually(func() chan *CacheMessage {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenC))
|
||||
})
|
||||
It("should not propagate the message over the channel if type is wrong", func() {
|
||||
var id []byte
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var binMsg []byte
|
||||
binMsg, err = json.Marshal(redisMessage{
|
||||
Key: "unknown",
|
||||
Type: 99,
|
||||
Message: []byte("test"),
|
||||
Client: id,
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
lenE := len(redisClient.EnabledChannel)
|
||||
lenC := len(redisClient.CacheChannel)
|
||||
|
||||
rec := redisServer.Publish(SyncChannelName, string(binMsg))
|
||||
Expect(rec).Should(Equal(1))
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
Eventually(func() chan *EnabledMessage {
|
||||
return redisClient.EnabledChannel
|
||||
}).Should(HaveLen(lenE))
|
||||
|
||||
Eventually(func() chan *CacheMessage {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenC))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Read the redis cache and publish it to the channel", func() {
|
||||
When("GetRedisCache is called with valid database entries", func() {
|
||||
It("Should read data from Redis and propagate it via cache channel", func() {
|
||||
By("Database is empty", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
}).Should(BeEmpty())
|
||||
})
|
||||
|
||||
By("Put valid data in Redis by publishing the cache entry", func() {
|
||||
var res *dns.Msg
|
||||
|
||||
res, err = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishCache("example.com", res)
|
||||
})
|
||||
|
||||
By("Database has one entry now", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
}).Should(HaveLen(1))
|
||||
})
|
||||
|
||||
By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() {
|
||||
redisClient.GetRedisCache()
|
||||
|
||||
Eventually(redisClient.CacheChannel).Should(HaveLen(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
When("GetRedisCache is called and database contains not valid entry", func() {
|
||||
It("Should do nothing (only log error)", func() {
|
||||
Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed())
|
||||
redisClient.GetRedisCache()
|
||||
Consistently(redisClient.CacheChannel).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,50 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/rueian/rueidis"
|
||||
)
|
||||
|
||||
// New creates a new redis client
|
||||
func NewRedisClient(cfg *config.RedisConfig) (rueidis.Client, error) {
|
||||
// disable redis if no address is provided
|
||||
if cfg == nil || len(cfg.Addresses) == 0 {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
roption := rueidis.ClientOption{
|
||||
InitAddress: cfg.Addresses,
|
||||
Password: cfg.Password,
|
||||
Username: cfg.Username,
|
||||
SelectDB: cfg.Database,
|
||||
ClientName: fmt.Sprintf("blocky-%s", util.HostnameString()),
|
||||
ClientTrackingOptions: []string{"PREFIX", "blocky:", "BCAST"},
|
||||
}
|
||||
|
||||
if len(cfg.SentinelMasterSet) > 0 {
|
||||
roption.Sentinel = rueidis.SentinelOption{
|
||||
Username: cfg.SentinelUsername,
|
||||
Password: cfg.SentinelPassword,
|
||||
MasterSet: cfg.SentinelMasterSet,
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
var client rueidis.Client
|
||||
for i := 0; i < cfg.ConnectionAttempts; i++ {
|
||||
client, err = rueidis.NewClient(roption)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Duration(cfg.ConnectionCooldown))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
|
@ -28,7 +28,7 @@ type CachingResolver struct {
|
|||
|
||||
emitMetricEvents bool // disabled by Bootstrap
|
||||
|
||||
resultCache expirationcache.ExpiringCache[cacheValue]
|
||||
resultCache expirationcache.ExpiringCache[dns.Msg]
|
||||
prefetchingNameCache expirationcache.ExpiringCache[int]
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
@ -55,30 +55,20 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
|
|||
|
||||
configureCaches(c, &cfg)
|
||||
|
||||
if c.redisClient != nil {
|
||||
setupRedisCacheSubscriber(c)
|
||||
c.redisClient.GetRedisCache()
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
||||
cleanupOption := expirationcache.WithCleanUpInterval[cacheValue](defaultCachingCleanUpInterval)
|
||||
maxSizeOption := expirationcache.WithMaxSize[cacheValue](uint(cfg.MaxItemsCount))
|
||||
rdb, err := redis.NewRedisClient(&config.GetConfig().Redis)
|
||||
util.FatalOnError("can't create redis client", err)
|
||||
|
||||
if cfg.Prefetching {
|
||||
c.prefetchingNameCache = expirationcache.NewCache(
|
||||
expirationcache.WithCleanUpInterval[int](time.Minute),
|
||||
expirationcache.WithMaxSize[int](uint(cfg.PrefetchMaxItemsCount)),
|
||||
)
|
||||
|
||||
c.resultCache = expirationcache.NewCache(
|
||||
cleanupOption,
|
||||
maxSizeOption,
|
||||
expirationcache.WithOnExpiredFn(c.onExpired),
|
||||
)
|
||||
if rdb != nil {
|
||||
// redis
|
||||
c.resultCache = expirationcache.NewRedisCache(rdb, "query")
|
||||
} else {
|
||||
// in-memory
|
||||
cleanupOption := expirationcache.WithCleanUpInterval[dns.Msg](defaultCachingCleanUpInterval)
|
||||
maxSizeOption := expirationcache.WithMaxSize[dns.Msg](uint(cfg.MaxItemsCount))
|
||||
c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption)
|
||||
}
|
||||
}
|
||||
|
@ -162,14 +152,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
|||
|
||||
r.publishMetricsIfEnabled(evt.CachingResultCacheHit, domain)
|
||||
|
||||
if val.prefetch {
|
||||
// Hit from prefetch cache
|
||||
r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
|
||||
}
|
||||
//if val.prefetch {
|
||||
// // Hit from prefetch cache
|
||||
// r.publishMetricsIfEnabled(evt.CachingPrefetchCacheHit, domain)
|
||||
//}
|
||||
|
||||
resp := val.resultMsg.Copy()
|
||||
resp := val.Copy()
|
||||
resp.SetReply(request.Req)
|
||||
resp.Rcode = val.resultMsg.Rcode
|
||||
resp.Rcode = val.Rcode
|
||||
|
||||
// Adjust TTL
|
||||
for _, rr := range resp.Answer {
|
||||
|
@ -215,21 +205,15 @@ func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, log
|
|||
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, prefetch, publish bool) {
|
||||
if response.Res.Rcode == dns.RcodeSuccess {
|
||||
// put value into cache
|
||||
r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer))
|
||||
r.resultCache.Put(cacheKey, response.Res, r.adjustTTLs(response.Res.Answer))
|
||||
} else if response.Res.Rcode == dns.RcodeNameError {
|
||||
if r.cfg.CacheTimeNegative > 0 {
|
||||
// put negative cache if result code is NXDOMAIN
|
||||
r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
|
||||
r.resultCache.Put(cacheKey, response.Res, r.cfg.CacheTimeNegative.ToDuration())
|
||||
}
|
||||
}
|
||||
|
||||
r.publishMetricsIfEnabled(evt.CachingResultCacheChanged, r.resultCache.TotalCount())
|
||||
|
||||
if publish && r.redisClient != nil {
|
||||
res := *response.Res
|
||||
res.Answer = response.Res.Answer
|
||||
r.redisClient.PublishCache(cacheKey, &res)
|
||||
}
|
||||
}
|
||||
|
||||
// adjustTTLs calculates and returns the max TTL (considers also the min and max cache time)
|
||||
|
|
|
@ -151,12 +151,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
redisClient, redisErr := redis.New(&cfg.Redis)
|
||||
if redisErr != nil && cfg.Redis.Required {
|
||||
return nil, redisErr
|
||||
}
|
||||
|
||||
queryResolver, queryError := createQueryResolver(cfg, bootstrap, redisClient)
|
||||
queryResolver, queryError := createQueryResolver(cfg, bootstrap, nil)
|
||||
if queryError != nil {
|
||||
return nil, queryError
|
||||
}
|
||||
|
|
|
@ -165,13 +165,7 @@ func FatalOnError(message string, err error) {
|
|||
|
||||
// GenerateCacheKey return cacheKey by query type/domain
|
||||
func GenerateCacheKey(qType dns.Type, qName string) string {
|
||||
const qTypeLength = 2
|
||||
b := make([]byte, qTypeLength+len(qName))
|
||||
|
||||
binary.BigEndian.PutUint16(b, uint16(qType))
|
||||
copy(b[2:], strings.ToLower(qName))
|
||||
|
||||
return string(b)
|
||||
return dns.TypeToString[uint16(qType)] + ":" + qName
|
||||
}
|
||||
|
||||
// ExtractCacheKey return query type/domain from cacheKey
|
||||
|
|
Loading…
Reference in New Issue