Refactoring Redis (#1271)

* RedisConfig -> Redis

* moved redis config to seperate file

* bugfix in config test during parallel processing

* implement config.Configurable in Redis config

* use Context in GetRedisCache

* use Context in New

* caching resolver test fix

* use Context in PublishEnabled

* use Context in getResponse

* remove ctx field

* bugfix in api interface test

* propperly close channels

* set ruler for go files from 80 to 111

* line break because function length is to long

* only execute redis.New if it is enabled in config

* stabilized flaky tests

* Update config/redis.go

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

* Update config/redis_test.go

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

* Update config/redis_test.go

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

* Update config/redis_test.go

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

* Update config/redis.go

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

* Update config/redis_test.go

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>

* fix ruler

* redis test refactoring

* vscode setting cleanup

* removed else if chain

* Update redis_test.go

* context race fix

* test fail on missing seintinel servers

* cleanup context usage

* cleanup2

* context fixes

* added context util

* disabled nil context rule for tests

* copy paste error ctxSend -> CtxSend

* use util.CtxSend

* fixed comment

* fixed flaky test

* failsafe and tests

---------

Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
This commit is contained in:
Kwitsch 2023-11-27 18:08:31 +01:00 committed by GitHub
parent 15bd383460
commit fda2dbe9df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 576 additions and 199 deletions

View File

@ -31,19 +31,13 @@
"GitHub.vscode-github-actions"
],
"settings": {
"go.lintFlags": ["--config=${containerWorkspaceFolder}/.golangci.yml"],
"go.lintFlags": [
"--config=${containerWorkspaceFolder}/.golangci.yml",
"--fast"
],
"go.alternateTools": {
"go-langserver": "gopls"
},
"[go]": {
"editor.defaultFormatter": "golang.go"
},
"[json][jsonc][github-actions-workflow]": {
"editor.defaultFormatter": "esbenp.prettier-vscode"
},
"[markdown]": {
"editor.defaultFormatter": "yzhang.markdown-all-in-one"
},
"markiscodecoverage.searchCriteria": "**/*.lcov",
"runItOn": {
"commands": [
@ -52,6 +46,15 @@
"cmd": "${workspaceRoot}/.devcontainer/scripts/runItOnGo.sh ${fileDirname} ${workspaceRoot}"
}
]
},
"[go]": {
"editor.defaultFormatter": "golang.go"
},
"[json][jsonc][github-actions-workflow]": {
"editor.defaultFormatter": "esbenp.prettier-vscode"
},
"[markdown]": {
"editor.defaultFormatter": "yzhang.markdown-all-in-one"
}
}
}

View File

@ -98,3 +98,7 @@ issues:
- gochecknoinits
- gochecknoglobals
- gosec
- path: _test\.go
linters:
- staticcheck
text: "SA1012:"

View File

@ -8,6 +8,7 @@
"source.organizeImports": true,
"source.fixAll": true
},
"editor.rulers": [120],
"go.showWelcome": false,
"go.survey.prompt": false,
"go.useLanguageServer": true,

View File

@ -29,8 +29,8 @@ type BlockingStatus struct {
// BlockingControl interface to control the blocking status
type BlockingControl interface {
EnableBlocking()
DisableBlocking(duration time.Duration, disableGroups []string) error
EnableBlocking(ctx context.Context)
DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error
BlockingStatus() BlockingStatus
}
@ -71,7 +71,7 @@ func NewOpenAPIInterfaceImpl(control BlockingControl,
}
}
func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
func (i *OpenAPIInterfaceImpl) DisableBlocking(ctx context.Context,
request DisableBlockingRequestObject,
) (DisableBlockingResponseObject, error) {
var (
@ -91,7 +91,7 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
groups = strings.Split(*request.Params.Groups, ",")
}
err = i.control.DisableBlocking(duration, groups)
err = i.control.DisableBlocking(ctx, duration, groups)
if err != nil {
return DisableBlocking400TextResponse(log.EscapeInput(err.Error())), nil
@ -100,9 +100,9 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
return DisableBlocking200Response{}, nil
}
func (i *OpenAPIInterfaceImpl) EnableBlocking(_ context.Context, _ EnableBlockingRequestObject,
func (i *OpenAPIInterfaceImpl) EnableBlocking(ctx context.Context, _ EnableBlockingRequestObject,
) (EnableBlockingResponseObject, error) {
i.control.EnableBlocking()
i.control.EnableBlocking(ctx)
return EnableBlocking200Response{}, nil
}

View File

@ -37,11 +37,11 @@ func (m *ListRefreshMock) RefreshLists() error {
return args.Error(0)
}
func (m *BlockingControlMock) EnableBlocking() {
func (m *BlockingControlMock) EnableBlocking(_ context.Context) {
_ = m.Called()
}
func (m *BlockingControlMock) DisableBlocking(t time.Duration, g []string) error {
func (m *BlockingControlMock) DisableBlocking(_ context.Context, t time.Duration, g []string) error {
args := m.Called(t, g)
return args.Error(0)

View File

@ -219,7 +219,7 @@ type Config struct {
Caching CachingConfig `yaml:"caching"`
QueryLog QueryLogConfig `yaml:"queryLog"`
Prometheus MetricsConfig `yaml:"prometheus"`
Redis RedisConfig `yaml:"redis"`
Redis Redis `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports PortsConfig `yaml:"ports"`
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
@ -280,20 +280,6 @@ type (
}
)
// RedisConfig configuration for the redis connection
type RedisConfig struct {
Address string `yaml:"address"`
Username string `yaml:"username" default:""`
Password string `yaml:"password" default:""`
Database int `yaml:"database" default:"0"`
Required bool `yaml:"required" default:"false"`
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
SentinelUsername string `yaml:"sentinelUsername" default:""`
SentinelPassword string `yaml:"sentinelPassword" default:""`
SentinelAddresses []string `yaml:"sentinelAddresses"`
}
type (
FQDNOnly = toEnable
EDE = toEnable

View File

@ -169,7 +169,7 @@ var _ = Describe("Config", func() {
err = writeConfigDir(tmpDir)
Expect(err).Should(Succeed())
_, err := LoadConfig(tmpDir.Path, true)
c, err = LoadConfig(tmpDir.Path, true)
Expect(err).Should(Succeed())
defaultTestFileConfig(c)

57
config/redis.go Normal file
View File

@ -0,0 +1,57 @@
package config
import (
"strings"
"github.com/sirupsen/logrus"
)
// Redis configuration for the redis connection
type Redis struct {
Address string `yaml:"address"`
Username string `yaml:"username" default:""`
Password string `yaml:"password" default:""`
Database int `yaml:"database" default:"0"`
Required bool `yaml:"required" default:"false"`
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
SentinelUsername string `yaml:"sentinelUsername" default:""`
SentinelPassword string `yaml:"sentinelPassword" default:""`
SentinelAddresses []string `yaml:"sentinelAddresses"`
}
// IsEnabled implements `config.Configurable`
func (c *Redis) IsEnabled() bool {
return c.Address != ""
}
// LogConfig implements `config.Configurable`
func (c *Redis) LogConfig(logger *logrus.Entry) {
if len(c.SentinelAddresses) == 0 {
logger.Info("address: ", c.Address)
}
logger.Info("username: ", c.Username)
logger.Info("password: ", obfuscatePassword(c.Password))
logger.Info("database: ", c.Database)
logger.Info("required: ", c.Required)
logger.Info("connectionAttempts: ", c.ConnectionAttempts)
logger.Info("connectionCooldown: ", c.ConnectionCooldown)
if len(c.SentinelAddresses) > 0 {
logger.Info("sentinel:")
logger.Info(" master: ", c.Address)
logger.Info(" username: ", c.SentinelUsername)
logger.Info(" password: ", obfuscatePassword(c.SentinelPassword))
logger.Info(" addresses:")
for _, addr := range c.SentinelAddresses {
logger.Info(" - ", addr)
}
}
}
// obfuscatePassword replaces all characters of a password except the first and last with *
func obfuscatePassword(pass string) string {
return strings.Repeat("*", len(pass))
}

104
config/redis_test.go Normal file
View File

@ -0,0 +1,104 @@
package config
import (
"github.com/0xERR0R/blocky/log"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Redis", func() {
var (
c Redis
err error
)
suiteBeforeEach()
BeforeEach(func() {
err = defaults.Set(&c)
Expect(err).Should(Succeed())
})
Describe("IsEnabled", func() {
When("all fields are default", func() {
It("should be disabled", func() {
Expect(c.IsEnabled()).Should(BeFalse())
})
})
When("Address is set", func() {
BeforeEach(func() {
c.Address = "localhost:6379"
})
It("should be enabled", func() {
Expect(c.IsEnabled()).Should(BeTrue())
})
})
})
Describe("LogConfig", func() {
BeforeEach(func() {
logger, hook = log.NewMockEntry()
})
When("all fields are default", func() {
It("should log default values", func() {
c.LogConfig(logger)
Expect(hook.Messages).Should(
SatisfyAll(ContainElement(ContainSubstring("address: ")),
ContainElement(ContainSubstring("username: ")),
ContainElement(ContainSubstring("password: ")),
ContainElement(ContainSubstring("database: ")),
ContainElement(ContainSubstring("required: ")),
ContainElement(ContainSubstring("connectionAttempts: ")),
ContainElement(ContainSubstring("connectionCooldown: "))))
})
})
When("Address is set", func() {
BeforeEach(func() {
c.Address = "localhost:6379"
})
It("should log address", func() {
c.LogConfig(logger)
Expect(hook.Messages).Should(ContainElement(ContainSubstring("address: localhost:6379")))
})
})
When("SentinelAddresses is set", func() {
BeforeEach(func() {
c.SentinelAddresses = []string{"localhost:26379", "localhost:26380"}
})
It("should log sentinel addresses", func() {
c.LogConfig(logger)
Expect(hook.Messages).Should(
SatisfyAll(
ContainElement(ContainSubstring("sentinel:")),
ContainElement(ContainSubstring(" addresses:")),
ContainElement(ContainSubstring(" - localhost:26379")),
ContainElement(ContainSubstring(" - localhost:26380"))))
})
})
})
Describe("obfuscatePassword", func() {
When("password is empty", func() {
It("should return empty string", func() {
Expect(obfuscatePassword("")).Should(Equal(""))
})
})
When("password is not empty", func() {
It("should return obfuscated password", func() {
Expect(obfuscatePassword("test123")).Should(Equal("*******"))
})
})
})
})

View File

@ -12,6 +12,7 @@ import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
"github.com/miekg/dns"
@ -56,10 +57,9 @@ type EnabledMessage struct {
// Client for redis communication
type Client struct {
config *config.RedisConfig
config *config.Redis
client *redis.Client
l *logrus.Entry
ctx context.Context
id []byte
sendBuffer chan *bufferMessage
CacheChannel chan *CacheMessage
@ -67,15 +67,15 @@ type Client struct {
}
// New creates a new redis client
func New(cfg *config.RedisConfig) (*Client, error) {
func New(ctx context.Context, cfg *config.Redis) (*Client, error) {
// disable redis if no address is provided
if cfg == nil || len(cfg.Address) == 0 {
return nil, nil //nolint:nilnil
}
var rdb *redis.Client
var baseClient *redis.Client
if len(cfg.SentinelAddresses) > 0 {
rdb = redis.NewFailoverClient(&redis.FailoverOptions{
baseClient = redis.NewFailoverClient(&redis.FailoverOptions{
MasterName: cfg.Address,
SentinelUsername: cfg.Username,
SentinelPassword: cfg.SentinelPassword,
@ -87,7 +87,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
})
} else {
rdb = redis.NewClient(&redis.Options{
baseClient = redis.NewClient(&redis.Options{
Addr: cfg.Address,
Username: cfg.Username,
Password: cfg.Password,
@ -97,7 +97,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
})
}
ctx := context.Background()
rdb := baseClient.WithContext(ctx)
_, err := rdb.Ping(ctx).Result()
if err == nil {
@ -110,7 +110,6 @@ func New(cfg *config.RedisConfig) (*Client, error) {
config: cfg,
client: rdb,
l: log.PrefixedLog("redis"),
ctx: ctx,
id: id,
sendBuffer: make(chan *bufferMessage, chanCap),
CacheChannel: make(chan *CacheMessage, chanCap),
@ -118,7 +117,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
}
// start channel handling go routine
err = res.startup()
err = res.startup(ctx)
return res, err
}
@ -137,7 +136,7 @@ func (c *Client) PublishCache(key string, message *dns.Msg) {
}
}
func (c *Client) PublishEnabled(state *EnabledMessage) {
func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) {
binState, sErr := json.Marshal(state)
if sErr == nil {
binMsg, mErr := json.Marshal(redisMessage{
@ -147,22 +146,30 @@ func (c *Client) PublishEnabled(state *EnabledMessage) {
})
if mErr == nil {
c.client.Publish(c.ctx, SyncChannelName, binMsg)
c.client.Publish(ctx, SyncChannelName, binMsg)
}
}
}
// GetRedisCache reads the redis cache and publish it to the channel
func (c *Client) GetRedisCache() {
func (c *Client) GetRedisCache(ctx context.Context) {
c.l.Debug("GetRedisCache")
go func() {
iter := c.client.Scan(c.ctx, 0, prefixKey("*"), 0).Iterator()
for iter.Next(c.ctx) {
response, err := c.getResponse(iter.Val())
iter := c.client.Scan(ctx, 0, prefixKey("*"), 0).Iterator()
if err := iter.Err(); err != nil {
c.l.Error("GetRedisCache ", err)
return
}
for iter.Next(ctx) {
response, err := c.getResponse(ctx, iter.Val())
if err == nil {
if response != nil {
c.CacheChannel <- response
if !util.CtxSend(ctx, c.CacheChannel, response) {
return
}
}
} else {
c.l.Error("GetRedisCache ", err)
@ -172,10 +179,10 @@ func (c *Client) GetRedisCache() {
}
// startup starts a new goroutine for subscription and translation
func (c *Client) startup() error {
ps := c.client.Subscribe(c.ctx, SyncChannelName)
func (c *Client) startup(ctx context.Context) error {
ps := c.client.Subscribe(ctx, SyncChannelName)
_, err := ps.Receive(c.ctx)
_, err := ps.Receive(ctx)
if err == nil {
go func() {
for {
@ -186,11 +193,16 @@ func (c *Client) startup() error {
if msg != nil && len(msg.Payload) > 0 {
// message is not empty
c.processReceivedMessage(msg)
c.processReceivedMessage(ctx, msg)
}
// publish message from buffer
case s := <-c.sendBuffer:
c.publishMessageFromBuffer(s)
c.publishMessageFromBuffer(ctx, s)
// context is done
case <-ctx.Done():
c.client.Close()
return
}
}
}()
@ -199,7 +211,7 @@ func (c *Client) startup() error {
return err
}
func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) {
origRes := s.Message
origRes.Compress = true
binRes, pErr := origRes.Pack()
@ -213,61 +225,61 @@ func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
})
if mErr == nil {
c.client.Publish(c.ctx, SyncChannelName, binMsg)
c.client.Publish(ctx, SyncChannelName, binMsg)
}
c.client.Set(c.ctx,
c.client.Set(ctx,
prefixKey(s.Key),
binRes,
c.getTTL(origRes))
}
}
func (c *Client) processReceivedMessage(msg *redis.Message) {
func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) {
var rm redisMessage
err := json.Unmarshal([]byte(msg.Payload), &rm)
if err == nil {
// message was sent from a different blocky instance
if !bytes.Equal(rm.Client, c.id) {
switch rm.Type {
case messageTypeCache:
var cm *CacheMessage
if err := json.Unmarshal([]byte(msg.Payload), &rm); err != nil {
c.l.Error("Processing error: ", err)
cm, err = convertMessage(&rm, 0)
if err == nil {
c.CacheChannel <- cm
}
case messageTypeEnable:
err = c.processEnabledMessage(&rm)
default:
c.l.Warn("Unknown message type: ", rm.Type)
return
}
// message was sent from a different blocky instance
if !bytes.Equal(rm.Client, c.id) {
switch rm.Type {
case messageTypeCache:
var cm *CacheMessage
cm, err := convertMessage(&rm, 0)
if err != nil {
c.l.Error("Processing CacheMessage error: ", err)
return
}
util.CtxSend(ctx, c.CacheChannel, cm)
case messageTypeEnable:
var msg EnabledMessage
if err := json.Unmarshal(rm.Message, &msg); err != nil {
c.l.Error("Processing EnabledMessage error: ", err)
return
}
util.CtxSend(ctx, c.EnabledChannel, &msg)
default:
c.l.Warn("Unknown message type: ", rm.Type)
}
}
if err != nil {
c.l.Error("Processing error: ", err)
}
}
func (c *Client) processEnabledMessage(redisMsg *redisMessage) error {
var msg EnabledMessage
err := json.Unmarshal(redisMsg.Message, &msg)
if err == nil {
c.EnabledChannel <- &msg
}
return err
}
// getResponse returns model.Response for a key
func (c *Client) getResponse(key string) (*CacheMessage, error) {
resp, err := c.client.Get(c.ctx, key).Result()
func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) {
resp, err := c.client.Get(ctx, key).Result()
if err == nil {
var ttl time.Duration
ttl, err = c.client.TTL(c.ctx, key).Result()
ttl, err = c.client.TTL(ctx, key).Result()
if err == nil {
var result *CacheMessage

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"encoding/json"
"time"
@ -18,76 +19,75 @@ const (
exampleComKey = CacheStorePrefix + "example.com"
)
var (
redisServer *miniredis.Miniredis
redisClient *Client
redisConfig *config.RedisConfig
err error
)
var _ = Describe("Redis client", func() {
var (
redisConfig *config.Redis
redisClient *Client
err error
)
BeforeEach(func() {
redisServer, err = miniredis.Run()
Expect(err).Should(Succeed())
DeferCleanup(redisServer.Close)
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = redisServer.Addr()
var rcfg config.Redis
Expect(defaults.Set(&rcfg)).Should(Succeed())
redisConfig = &rcfg
redisClient, err = New(redisConfig)
Expect(err).Should(Succeed())
Expect(redisClient).ShouldNot(BeNil())
})
Describe("Client creation", func() {
When("redis configuration has no address", func() {
It("should return nil without error", func() {
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
Expect(New(&rcfg)).Should(BeNil())
It("should return nil without error", func(ctx context.Context) {
Expect(New(ctx, redisConfig)).Should(BeNil())
})
})
When("redis configuration has invalid address", func() {
It("should fail with error", func() {
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = "127.0.0.1:0"
_, err = New(&rcfg)
BeforeEach(func() {
redisConfig.Address = "127.0.0.1:0"
})
It("should fail with error", func(ctx context.Context) {
_, err = New(ctx, redisConfig)
Expect(err).Should(HaveOccurred())
})
})
When("sentinel is enabled without servers", func() {
BeforeEach(func() {
redisConfig.Address = "test"
redisConfig.SentinelAddresses = []string{"127.0.0.1:0"}
})
It("should fail with error", func(ctx context.Context) {
_, err = New(ctx, redisConfig)
Expect(err).Should(HaveOccurred())
})
})
When("redis configuration has invalid password", func() {
It("should fail with error", func() {
var rcfg config.RedisConfig
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = redisServer.Addr()
rcfg.Password = "wrong"
_, err = New(&rcfg)
BeforeEach(func() {
setupRedisServer(redisConfig)
redisConfig.Password = "wrong"
})
It("should fail with error", func(ctx context.Context) {
_, err = New(ctx, redisConfig)
Expect(err).Should(HaveOccurred())
})
})
})
Describe("Publish message", func() {
var redisServer *miniredis.Miniredis
BeforeEach(func() {
redisServer = setupRedisServer(redisConfig)
})
When("Redis client publishes 'cache' message", func() {
It("One new entry with TTL > 0 should be persisted in the database", func() {
It("One new entry with TTL > 0 should be persisted in the database", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
By("Database is empty", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
@ -112,7 +112,10 @@ var _ = Describe("Redis client", func() {
})
})
It("One new entry with default TTL should be persisted in the database", func() {
It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
By("Database is empty", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
@ -138,20 +141,30 @@ var _ = Describe("Redis client", func() {
})
})
When("Redis client publishes 'enabled' message", func() {
It("should propagate the message over redis", func() {
redisClient.PublishEnabled(&EnabledMessage{
It("should propagate the message over redis", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
redisClient.PublishEnabled(ctx, &EnabledMessage{
State: true,
})
Eventually(func() map[string]int {
return redisServer.PubSubNumSub(SyncChannelName)
}).Should(HaveLen(1))
})
}, SpecTimeout(time.Second*6))
})
})
Describe("Receive message", func() {
var redisServer *miniredis.Miniredis
BeforeEach(func() {
redisServer = setupRedisServer(redisConfig)
})
When("'enabled' message is received", func() {
It("should propagate the message over the channel", func() {
It("should propagate the message over the channel", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
var binState []byte
binState, err = json.Marshal(EnabledMessage{State: true})
Expect(err).Should(Succeed())
@ -179,7 +192,10 @@ var _ = Describe("Redis client", func() {
})
})
When("'cache' message is received", func() {
It("should propagate the message over the channel", func() {
It("should propagate the message over the channel", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
Expect(err).Should(Succeed())
@ -209,10 +225,13 @@ var _ = Describe("Redis client", func() {
Eventually(func() chan *CacheMessage {
return redisClient.CacheChannel
}).Should(HaveLen(lenE + 1))
})
}, SpecTimeout(time.Second*6))
})
When("wrong data is received", func() {
It("should not propagate the message over the channel if data is wrong", func() {
It("should not propagate the message over the channel if data is wrong", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
var id []byte
id, err = uuid.New().MarshalBinary()
Expect(err).Should(Succeed())
@ -239,8 +258,11 @@ var _ = Describe("Redis client", func() {
Eventually(func() chan *CacheMessage {
return redisClient.CacheChannel
}).Should(HaveLen(lenC))
})
It("should not propagate the message over the channel if type is wrong", func() {
}, SpecTimeout(time.Second*6))
It("should not propagate the message over the channel if type is wrong", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
var id []byte
id, err = uuid.New().MarshalBinary()
Expect(err).Should(Succeed())
@ -269,13 +291,20 @@ var _ = Describe("Redis client", func() {
Eventually(func() chan *CacheMessage {
return redisClient.CacheChannel
}).Should(HaveLen(lenC))
})
}, SpecTimeout(time.Second*6))
})
})
Describe("Read the redis cache and publish it to the channel", func() {
var redisServer *miniredis.Miniredis
BeforeEach(func() {
redisServer = setupRedisServer(redisConfig)
})
When("GetRedisCache is called with valid database entries", func() {
It("Should read data from Redis and propagate it via cache channel", func() {
It("Should read data from Redis and propagate it via cache channel", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
By("Database is empty", func() {
Eventually(func() []string {
return redisServer.DB(redisConfig.Database).Keys()
@ -299,18 +328,30 @@ var _ = Describe("Redis client", func() {
})
By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() {
redisClient.GetRedisCache()
redisClient.GetRedisCache(ctx)
Eventually(redisClient.CacheChannel).Should(HaveLen(1))
})
})
}, SpecTimeout(time.Second*4))
})
When("GetRedisCache is called and database contains not valid entry", func() {
It("Should do nothing (only log error)", func() {
It("Should do nothing (only log error)", func(ctx context.Context) {
redisClient, err = New(ctx, redisConfig)
Expect(err).Should(Succeed())
Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed())
redisClient.GetRedisCache()
redisClient.GetRedisCache(ctx)
Consistently(redisClient.CacheChannel).Should(BeEmpty())
})
}, SpecTimeout(time.Second*2))
})
})
})
func setupRedisServer(cfg *config.Redis) *miniredis.Miniredis {
redisServer, err := miniredis.Run()
Expect(err).Should(Succeed())
DeferCleanup(redisServer.Close)
cfg.Address = redisServer.Addr()
return redisServer
}

View File

@ -183,7 +183,7 @@ func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
if em.State {
r.internalEnableBlocking()
} else {
err := r.internalDisableBlocking(em.Duration, em.Groups)
err := r.internalDisableBlocking(ctx, em.Duration, em.Groups)
if err != nil {
r.log().Warn("Blocking couldn't be disabled:", err)
}
@ -216,11 +216,11 @@ func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
}
// EnableBlocking enables the blocking against the blacklists
func (r *BlockingResolver) EnableBlocking() {
func (r *BlockingResolver) EnableBlocking(ctx context.Context) {
r.internalEnableBlocking()
if r.redisClient != nil {
r.redisClient.PublishEnabled(&redis.EnabledMessage{State: true})
r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{State: true})
}
}
@ -236,10 +236,10 @@ func (r *BlockingResolver) internalEnableBlocking() {
}
// DisableBlocking deactivates the blocking for a particular duration (or forever if 0).
func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups []string) error {
err := r.internalDisableBlocking(duration, disableGroups)
func (r *BlockingResolver) DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error {
err := r.internalDisableBlocking(ctx, duration, disableGroups)
if err == nil && r.redisClient != nil {
r.redisClient.PublishEnabled(&redis.EnabledMessage{
r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{
State: false,
Duration: duration,
Groups: disableGroups,
@ -249,7 +249,9 @@ func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups
return err
}
func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disableGroups []string) error {
func (r *BlockingResolver) internalDisableBlocking(ctx context.Context, duration time.Duration,
disableGroups []string,
) error {
s := r.status
s.lock.Lock()
defer s.lock.Unlock()
@ -280,7 +282,7 @@ func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disab
log.Log().Infof("disable blocking for %s for group(s) '%s'", duration,
log.EscapeInput(strings.Join(s.disabledGroups, "; ")))
s.enableTimer = time.AfterFunc(duration, func() {
r.EnableBlocking()
r.EnableBlocking(ctx)
log.Log().Info("blocking enabled again")
})
}

View File

@ -842,7 +842,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("Calling Rest API to deactivate all groups", func() {
err := sut.DisableBlocking(0, []string{})
err := sut.DisableBlocking(context.TODO(), 0, []string{})
Expect(err).Should(Succeed())
})
@ -875,7 +875,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("Calling Rest API to deactivate only defaultGroup", func() {
err := sut.DisableBlocking(0, []string{"defaultGroup"})
err := sut.DisableBlocking(context.TODO(), 0, []string{"defaultGroup"})
Expect(err).Should(Succeed())
})
@ -935,7 +935,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
enabled <- state
})
Expect(err).Should(Succeed())
err = sut.DisableBlocking(500*time.Millisecond, []string{})
err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{})
Expect(err).Should(Succeed())
Eventually(enabled, "1s").Should(Receive(BeFalse()))
})
@ -1025,7 +1025,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
enabled <- false
})
Expect(err).Should(Succeed())
err = sut.DisableBlocking(500*time.Millisecond, []string{"group1"})
err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"group1"})
Expect(err).Should(Succeed())
Eventually(enabled, "1s").Should(Receive(BeFalse()))
})
@ -1086,7 +1086,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking is called with wrong group name", func() {
It("should fail", func() {
err := sut.DisableBlocking(500*time.Millisecond, []string{"unknownGroupName"})
err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"unknownGroupName"})
Expect(err).Should(HaveOccurred())
})
})
@ -1094,7 +1094,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Blocking status is called", func() {
It("should return correct status", func() {
By("enable blocking via API", func() {
sut.EnableBlocking()
sut.EnableBlocking(context.TODO())
})
By("Query blocking status via API should return 'enabled'", func() {
@ -1103,7 +1103,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("disable blocking via API", func() {
err := sut.DisableBlocking(500*time.Millisecond, []string{})
err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{})
Expect(err).Should(Succeed())
})
@ -1149,12 +1149,12 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
Expect(err).Should(Succeed())
var rcfg config.RedisConfig
var rcfg config.Redis
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = redisServer.Addr()
redisClient, err = redis.New(&rcfg)
redisClient, err = redis.New(context.TODO(), &rcfg)
Expect(err).Should(Succeed())
Expect(redisClient).ShouldNot(BeNil())
@ -1171,7 +1171,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("disable", func() {
It("should return disable", func() {
sut.EnableBlocking()
sut.EnableBlocking(context.TODO())
redisMockMsg := &redis.EnabledMessage{
State: false,
@ -1185,7 +1185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("disable", func() {
It("should return disable", func() {
sut.EnableBlocking()
sut.EnableBlocking(context.TODO())
redisMockMsg := &redis.EnabledMessage{
State: false,
Groups: []string{"unknown"},
@ -1199,7 +1199,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("enable", func() {
It("should return enable", func() {
err = sut.DisableBlocking(time.Hour, []string{})
err = sut.DisableBlocking(context.TODO(), time.Hour, []string{})
Expect(err).Should(Succeed())
redisMockMsg := &redis.EnabledMessage{

View File

@ -60,7 +60,7 @@ func newCachingResolver(ctx context.Context,
if c.redisClient != nil {
go c.redisSubscriber(ctx)
c.redisClient.GetRedisCache()
c.redisClient.GetRedisCache(ctx)
}
return c

View File

@ -733,7 +733,7 @@ var _ = Describe("CachingResolver", func() {
var (
redisServer *miniredis.Miniredis
redisClient *redis.Client
redisConfig *config.RedisConfig
redisConfig *config.Redis
err error
)
BeforeEach(func() {
@ -741,14 +741,14 @@ var _ = Describe("CachingResolver", func() {
Expect(err).Should(Succeed())
var rcfg config.RedisConfig
var rcfg config.Redis
err = defaults.Set(&rcfg)
Expect(err).Should(Succeed())
rcfg.Address = redisServer.Addr()
redisConfig = &rcfg
redisClient, err = redis.New(redisConfig)
redisClient, err = redis.New(context.TODO(), redisConfig)
Expect(err).Should(Succeed())
Expect(redisClient).ShouldNot(BeNil())

View File

@ -452,14 +452,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Describe("Weighted random on resolver selection", func() {
When("4 upstream resolvers are defined", func() {
var (
mockUpstream1 *MockUDPUpstreamServer
mockUpstream2 *MockUDPUpstreamServer
)
BeforeEach(func() {
withError1 := config.Upstream{Host: "wrong1"}
withError2 := config.Upstream{Host: "wrong2"}
mockUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
mockUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream1.Close)
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
mockUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}

View File

@ -31,6 +31,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
ctx context.Context
cancelFn context.CancelFunc
timeout = 2 * time.Second
testUpstream1 *MockUDPUpstreamServer
testUpstream2 *MockUDPUpstreamServer
)
Describe("Type", func() {
@ -58,6 +62,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
upstreamsCfg.StartVerify = sutVerify
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
sutConfig.Timeout = config.Duration(timeout)
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap)
})
@ -143,10 +148,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
When("Both are responding", func() {
When("they respond in time", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(testUpstream2.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
@ -165,8 +170,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
When("first upstream exceeds upstreamTimeout", func() {
BeforeEach(func() {
timeout := sut.cfg.Timeout.ToDuration()
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(2 * timeout)
@ -176,7 +180,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
DeferCleanup(testUpstream2.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
@ -193,9 +197,8 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
})
When("all upstreams exceed upsteamTimeout", func() {
BeforeEach(func() {
timeout := sut.cfg.Timeout.ToDuration()
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
JustBeforeEach(func() {
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(2 * timeout)
@ -205,7 +208,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
time.Sleep(2 * timeout)
@ -225,12 +228,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
When("Only second is working", func() {
BeforeEach(func() {
testUpstream1 := config.Upstream{Host: "wrong"}
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
DeferCleanup(testUpstream2.Close)
upstreams = []config.Upstream{testUpstream1, testUpstream2.Start()}
upstreams = []config.Upstream{{Host: "wrong"}, testUpstream2.Start()}
})
It("Should use result from second one", func() {
request := newRequest("example.com.", A)

View File

@ -135,9 +135,12 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, err
}
redisClient, redisErr := redis.New(&cfg.Redis)
if redisErr != nil && cfg.Redis.Required {
return nil, redisErr
var redisClient *redis.Client
if cfg.Redis.IsEnabled() {
redisClient, err = redis.New(ctx, &cfg.Redis)
if err != nil && cfg.Redis.Required {
return nil, err
}
}
queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient)
@ -423,6 +426,11 @@ func (s *Server) registerDNSHandlers() {
func (s *Server) printConfiguration() {
logger().Info("current configuration:")
if s.cfg.Redis.IsEnabled() {
logger().Info("Redis:")
log.WithIndent(logger(), " ", s.cfg.Redis.LogConfig)
}
resolver.ForEach(s.queryResolver, func(res resolver.Resolver) {
resolver.LogResolverConfig(res, logger())
})

29
util/context.go Normal file
View File

@ -0,0 +1,29 @@
package util
import "context"
// CtxSend sends a value to a channel while the context isn't done.
// If the message is sent, it returns true.
// If the context is done or the channel is closed, it returns false.
func CtxSend[T any](ctx context.Context, ch chan T, val T) (ok bool) {
if ctx == nil || ch == nil || ctx.Err() != nil {
ok = false
return
}
defer func() {
if err := recover(); err != nil {
ok = false
}
}()
select {
case <-ctx.Done():
ok = false
case ch <- val:
ok = true
}
return
}

125
util/context_test.go Normal file
View File

@ -0,0 +1,125 @@
package util
import (
"context"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const (
testMessage = 1
)
var _ = Describe("Context utils", func() {
Describe("CtxSend", func() {
var ch chan int
BeforeEach(func() {
ch = make(chan int, 1)
})
AfterEach(func() {
ch = nil
})
When("channel is not closed", func() {
It("should send value to channel", func(ctx context.Context) {
go startReader(ctx, ch)
Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue())
}, SpecTimeout(time.Second))
})
When("channel is closed", func() {
It("should return false", func(ctx context.Context) {
go startReader(ctx, ch)
close(ch)
Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse())
}, SpecTimeout(time.Second))
})
When("channel is nil", func() {
It("should return false", func(ctx context.Context) {
Expect(CtxSend(ctx, nil, testMessage)).Should(BeFalse())
}, SpecTimeout(time.Second))
})
When("channel is full", func() {
It("should wait", func(ctx context.Context) {
ch <- testMessage
go func(ctx context.Context, ch chan int) {
timer := time.NewTimer(time.Millisecond * 200)
select {
case <-timer.C:
startReader(ctx, ch)
case <-ctx.Done():
return
}
}(ctx, ch)
Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue())
}, SpecTimeout(time.Second))
})
When("context is done", func() {
It("should return false", func(ctx context.Context) {
go startReader(ctx, ch)
cCtx, cancel := context.WithCancel(ctx)
cancel()
<-cCtx.Done()
Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse())
}, SpecTimeout(time.Second))
})
When("context is terminated", func() {
It("should return false", func(ctx context.Context) {
ch <- testMessage
cCtx, cancel := context.WithCancel(ctx)
go func(ctx context.Context, cancel context.CancelFunc) {
timer := time.NewTimer(time.Millisecond * 200)
select {
case <-timer.C:
cancel()
case <-ctx.Done():
return
}
}(ctx, cancel)
Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse())
}, SpecTimeout(time.Second))
})
When("context is nil", func() {
It("should return false", func(ctx context.Context) {
Expect(CtxSend(nil, ch, testMessage)).Should(BeFalse())
}, SpecTimeout(time.Second))
})
When("context and channel are nil", func() {
It("should return false", func(ctx context.Context) {
Expect(CtxSend(nil, nil, testMessage)).Should(BeFalse())
}, SpecTimeout(time.Second))
})
When("context is done and channel is closed", func() {
It("should return false", func(ctx context.Context) {
go startReader(ctx, ch)
ctx, cancel := context.WithCancel(ctx)
cancel()
close(ch)
Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse())
}, SpecTimeout(time.Second))
})
})
})
func startReader(ctx context.Context, ch <-chan int) {
for {
select {
case <-ctx.Done():
return
case i, ok := <-ch:
if ok {
Expect(i).Should(Equal(testMessage))
}
}
}
}