mirror of https://github.com/0xERR0R/blocky.git
Refactoring Redis (#1271)
* RedisConfig -> Redis * moved redis config to seperate file * bugfix in config test during parallel processing * implement config.Configurable in Redis config * use Context in GetRedisCache * use Context in New * caching resolver test fix * use Context in PublishEnabled * use Context in getResponse * remove ctx field * bugfix in api interface test * propperly close channels * set ruler for go files from 80 to 111 * line break because function length is to long * only execute redis.New if it is enabled in config * stabilized flaky tests * Update config/redis.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * fix ruler * redis test refactoring * vscode setting cleanup * removed else if chain * Update redis_test.go * context race fix * test fail on missing seintinel servers * cleanup context usage * cleanup2 * context fixes * added context util * disabled nil context rule for tests * copy paste error ctxSend -> CtxSend * use util.CtxSend * fixed comment * fixed flaky test * failsafe and tests --------- Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
This commit is contained in:
parent
15bd383460
commit
fda2dbe9df
|
@ -31,19 +31,13 @@
|
|||
"GitHub.vscode-github-actions"
|
||||
],
|
||||
"settings": {
|
||||
"go.lintFlags": ["--config=${containerWorkspaceFolder}/.golangci.yml"],
|
||||
"go.lintFlags": [
|
||||
"--config=${containerWorkspaceFolder}/.golangci.yml",
|
||||
"--fast"
|
||||
],
|
||||
"go.alternateTools": {
|
||||
"go-langserver": "gopls"
|
||||
},
|
||||
"[go]": {
|
||||
"editor.defaultFormatter": "golang.go"
|
||||
},
|
||||
"[json][jsonc][github-actions-workflow]": {
|
||||
"editor.defaultFormatter": "esbenp.prettier-vscode"
|
||||
},
|
||||
"[markdown]": {
|
||||
"editor.defaultFormatter": "yzhang.markdown-all-in-one"
|
||||
},
|
||||
"markiscodecoverage.searchCriteria": "**/*.lcov",
|
||||
"runItOn": {
|
||||
"commands": [
|
||||
|
@ -52,6 +46,15 @@
|
|||
"cmd": "${workspaceRoot}/.devcontainer/scripts/runItOnGo.sh ${fileDirname} ${workspaceRoot}"
|
||||
}
|
||||
]
|
||||
},
|
||||
"[go]": {
|
||||
"editor.defaultFormatter": "golang.go"
|
||||
},
|
||||
"[json][jsonc][github-actions-workflow]": {
|
||||
"editor.defaultFormatter": "esbenp.prettier-vscode"
|
||||
},
|
||||
"[markdown]": {
|
||||
"editor.defaultFormatter": "yzhang.markdown-all-in-one"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,3 +98,7 @@ issues:
|
|||
- gochecknoinits
|
||||
- gochecknoglobals
|
||||
- gosec
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- staticcheck
|
||||
text: "SA1012:"
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
"source.organizeImports": true,
|
||||
"source.fixAll": true
|
||||
},
|
||||
"editor.rulers": [120],
|
||||
"go.showWelcome": false,
|
||||
"go.survey.prompt": false,
|
||||
"go.useLanguageServer": true,
|
||||
|
|
|
@ -29,8 +29,8 @@ type BlockingStatus struct {
|
|||
|
||||
// BlockingControl interface to control the blocking status
|
||||
type BlockingControl interface {
|
||||
EnableBlocking()
|
||||
DisableBlocking(duration time.Duration, disableGroups []string) error
|
||||
EnableBlocking(ctx context.Context)
|
||||
DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error
|
||||
BlockingStatus() BlockingStatus
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,7 @@ func NewOpenAPIInterfaceImpl(control BlockingControl,
|
|||
}
|
||||
}
|
||||
|
||||
func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
|
||||
func (i *OpenAPIInterfaceImpl) DisableBlocking(ctx context.Context,
|
||||
request DisableBlockingRequestObject,
|
||||
) (DisableBlockingResponseObject, error) {
|
||||
var (
|
||||
|
@ -91,7 +91,7 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
|
|||
groups = strings.Split(*request.Params.Groups, ",")
|
||||
}
|
||||
|
||||
err = i.control.DisableBlocking(duration, groups)
|
||||
err = i.control.DisableBlocking(ctx, duration, groups)
|
||||
|
||||
if err != nil {
|
||||
return DisableBlocking400TextResponse(log.EscapeInput(err.Error())), nil
|
||||
|
@ -100,9 +100,9 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
|
|||
return DisableBlocking200Response{}, nil
|
||||
}
|
||||
|
||||
func (i *OpenAPIInterfaceImpl) EnableBlocking(_ context.Context, _ EnableBlockingRequestObject,
|
||||
func (i *OpenAPIInterfaceImpl) EnableBlocking(ctx context.Context, _ EnableBlockingRequestObject,
|
||||
) (EnableBlockingResponseObject, error) {
|
||||
i.control.EnableBlocking()
|
||||
i.control.EnableBlocking(ctx)
|
||||
|
||||
return EnableBlocking200Response{}, nil
|
||||
}
|
||||
|
|
|
@ -37,11 +37,11 @@ func (m *ListRefreshMock) RefreshLists() error {
|
|||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *BlockingControlMock) EnableBlocking() {
|
||||
func (m *BlockingControlMock) EnableBlocking(_ context.Context) {
|
||||
_ = m.Called()
|
||||
}
|
||||
|
||||
func (m *BlockingControlMock) DisableBlocking(t time.Duration, g []string) error {
|
||||
func (m *BlockingControlMock) DisableBlocking(_ context.Context, t time.Duration, g []string) error {
|
||||
args := m.Called(t, g)
|
||||
|
||||
return args.Error(0)
|
||||
|
|
|
@ -219,7 +219,7 @@ type Config struct {
|
|||
Caching CachingConfig `yaml:"caching"`
|
||||
QueryLog QueryLogConfig `yaml:"queryLog"`
|
||||
Prometheus MetricsConfig `yaml:"prometheus"`
|
||||
Redis RedisConfig `yaml:"redis"`
|
||||
Redis Redis `yaml:"redis"`
|
||||
Log log.Config `yaml:"log"`
|
||||
Ports PortsConfig `yaml:"ports"`
|
||||
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
|
||||
|
@ -280,20 +280,6 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
// RedisConfig configuration for the redis connection
|
||||
type RedisConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
Username string `yaml:"username" default:""`
|
||||
Password string `yaml:"password" default:""`
|
||||
Database int `yaml:"database" default:"0"`
|
||||
Required bool `yaml:"required" default:"false"`
|
||||
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
|
||||
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
|
||||
SentinelUsername string `yaml:"sentinelUsername" default:""`
|
||||
SentinelPassword string `yaml:"sentinelPassword" default:""`
|
||||
SentinelAddresses []string `yaml:"sentinelAddresses"`
|
||||
}
|
||||
|
||||
type (
|
||||
FQDNOnly = toEnable
|
||||
EDE = toEnable
|
||||
|
|
|
@ -169,7 +169,7 @@ var _ = Describe("Config", func() {
|
|||
err = writeConfigDir(tmpDir)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
_, err := LoadConfig(tmpDir.Path, true)
|
||||
c, err = LoadConfig(tmpDir.Path, true)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
defaultTestFileConfig(c)
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Redis configuration for the redis connection
|
||||
type Redis struct {
|
||||
Address string `yaml:"address"`
|
||||
Username string `yaml:"username" default:""`
|
||||
Password string `yaml:"password" default:""`
|
||||
Database int `yaml:"database" default:"0"`
|
||||
Required bool `yaml:"required" default:"false"`
|
||||
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
|
||||
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
|
||||
SentinelUsername string `yaml:"sentinelUsername" default:""`
|
||||
SentinelPassword string `yaml:"sentinelPassword" default:""`
|
||||
SentinelAddresses []string `yaml:"sentinelAddresses"`
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`
|
||||
func (c *Redis) IsEnabled() bool {
|
||||
return c.Address != ""
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`
|
||||
func (c *Redis) LogConfig(logger *logrus.Entry) {
|
||||
if len(c.SentinelAddresses) == 0 {
|
||||
logger.Info("address: ", c.Address)
|
||||
}
|
||||
|
||||
logger.Info("username: ", c.Username)
|
||||
logger.Info("password: ", obfuscatePassword(c.Password))
|
||||
logger.Info("database: ", c.Database)
|
||||
logger.Info("required: ", c.Required)
|
||||
logger.Info("connectionAttempts: ", c.ConnectionAttempts)
|
||||
logger.Info("connectionCooldown: ", c.ConnectionCooldown)
|
||||
|
||||
if len(c.SentinelAddresses) > 0 {
|
||||
logger.Info("sentinel:")
|
||||
logger.Info(" master: ", c.Address)
|
||||
logger.Info(" username: ", c.SentinelUsername)
|
||||
logger.Info(" password: ", obfuscatePassword(c.SentinelPassword))
|
||||
logger.Info(" addresses:")
|
||||
|
||||
for _, addr := range c.SentinelAddresses {
|
||||
logger.Info(" - ", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// obfuscatePassword replaces all characters of a password except the first and last with *
|
||||
func obfuscatePassword(pass string) string {
|
||||
return strings.Repeat("*", len(pass))
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/creasty/defaults"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Redis", func() {
|
||||
var (
|
||||
c Redis
|
||||
err error
|
||||
)
|
||||
|
||||
suiteBeforeEach()
|
||||
|
||||
BeforeEach(func() {
|
||||
err = defaults.Set(&c)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
When("all fields are default", func() {
|
||||
It("should be disabled", func() {
|
||||
Expect(c.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
When("Address is set", func() {
|
||||
BeforeEach(func() {
|
||||
c.Address = "localhost:6379"
|
||||
})
|
||||
|
||||
It("should be enabled", func() {
|
||||
Expect(c.IsEnabled()).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
BeforeEach(func() {
|
||||
logger, hook = log.NewMockEntry()
|
||||
})
|
||||
|
||||
When("all fields are default", func() {
|
||||
It("should log default values", func() {
|
||||
c.LogConfig(logger)
|
||||
|
||||
Expect(hook.Messages).Should(
|
||||
SatisfyAll(ContainElement(ContainSubstring("address: ")),
|
||||
ContainElement(ContainSubstring("username: ")),
|
||||
ContainElement(ContainSubstring("password: ")),
|
||||
ContainElement(ContainSubstring("database: ")),
|
||||
ContainElement(ContainSubstring("required: ")),
|
||||
ContainElement(ContainSubstring("connectionAttempts: ")),
|
||||
ContainElement(ContainSubstring("connectionCooldown: "))))
|
||||
})
|
||||
})
|
||||
|
||||
When("Address is set", func() {
|
||||
BeforeEach(func() {
|
||||
c.Address = "localhost:6379"
|
||||
})
|
||||
|
||||
It("should log address", func() {
|
||||
c.LogConfig(logger)
|
||||
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("address: localhost:6379")))
|
||||
})
|
||||
})
|
||||
|
||||
When("SentinelAddresses is set", func() {
|
||||
BeforeEach(func() {
|
||||
c.SentinelAddresses = []string{"localhost:26379", "localhost:26380"}
|
||||
})
|
||||
|
||||
It("should log sentinel addresses", func() {
|
||||
c.LogConfig(logger)
|
||||
|
||||
Expect(hook.Messages).Should(
|
||||
SatisfyAll(
|
||||
ContainElement(ContainSubstring("sentinel:")),
|
||||
ContainElement(ContainSubstring(" addresses:")),
|
||||
ContainElement(ContainSubstring(" - localhost:26379")),
|
||||
ContainElement(ContainSubstring(" - localhost:26380"))))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("obfuscatePassword", func() {
|
||||
When("password is empty", func() {
|
||||
It("should return empty string", func() {
|
||||
Expect(obfuscatePassword("")).Should(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
When("password is not empty", func() {
|
||||
It("should return obfuscated password", func() {
|
||||
Expect(obfuscatePassword("test123")).Should(Equal("*******"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
128
redis/redis.go
128
redis/redis.go
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/google/uuid"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -56,10 +57,9 @@ type EnabledMessage struct {
|
|||
|
||||
// Client for redis communication
|
||||
type Client struct {
|
||||
config *config.RedisConfig
|
||||
config *config.Redis
|
||||
client *redis.Client
|
||||
l *logrus.Entry
|
||||
ctx context.Context
|
||||
id []byte
|
||||
sendBuffer chan *bufferMessage
|
||||
CacheChannel chan *CacheMessage
|
||||
|
@ -67,15 +67,15 @@ type Client struct {
|
|||
}
|
||||
|
||||
// New creates a new redis client
|
||||
func New(cfg *config.RedisConfig) (*Client, error) {
|
||||
func New(ctx context.Context, cfg *config.Redis) (*Client, error) {
|
||||
// disable redis if no address is provided
|
||||
if cfg == nil || len(cfg.Address) == 0 {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
var rdb *redis.Client
|
||||
var baseClient *redis.Client
|
||||
if len(cfg.SentinelAddresses) > 0 {
|
||||
rdb = redis.NewFailoverClient(&redis.FailoverOptions{
|
||||
baseClient = redis.NewFailoverClient(&redis.FailoverOptions{
|
||||
MasterName: cfg.Address,
|
||||
SentinelUsername: cfg.Username,
|
||||
SentinelPassword: cfg.SentinelPassword,
|
||||
|
@ -87,7 +87,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
|||
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
||||
})
|
||||
} else {
|
||||
rdb = redis.NewClient(&redis.Options{
|
||||
baseClient = redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Address,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
|
@ -97,7 +97,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
|||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
rdb := baseClient.WithContext(ctx)
|
||||
|
||||
_, err := rdb.Ping(ctx).Result()
|
||||
if err == nil {
|
||||
|
@ -110,7 +110,6 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
|||
config: cfg,
|
||||
client: rdb,
|
||||
l: log.PrefixedLog("redis"),
|
||||
ctx: ctx,
|
||||
id: id,
|
||||
sendBuffer: make(chan *bufferMessage, chanCap),
|
||||
CacheChannel: make(chan *CacheMessage, chanCap),
|
||||
|
@ -118,7 +117,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
|||
}
|
||||
|
||||
// start channel handling go routine
|
||||
err = res.startup()
|
||||
err = res.startup(ctx)
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
@ -137,7 +136,7 @@ func (c *Client) PublishCache(key string, message *dns.Msg) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Client) PublishEnabled(state *EnabledMessage) {
|
||||
func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) {
|
||||
binState, sErr := json.Marshal(state)
|
||||
if sErr == nil {
|
||||
binMsg, mErr := json.Marshal(redisMessage{
|
||||
|
@ -147,22 +146,30 @@ func (c *Client) PublishEnabled(state *EnabledMessage) {
|
|||
})
|
||||
|
||||
if mErr == nil {
|
||||
c.client.Publish(c.ctx, SyncChannelName, binMsg)
|
||||
c.client.Publish(ctx, SyncChannelName, binMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRedisCache reads the redis cache and publish it to the channel
|
||||
func (c *Client) GetRedisCache() {
|
||||
func (c *Client) GetRedisCache(ctx context.Context) {
|
||||
c.l.Debug("GetRedisCache")
|
||||
|
||||
go func() {
|
||||
iter := c.client.Scan(c.ctx, 0, prefixKey("*"), 0).Iterator()
|
||||
for iter.Next(c.ctx) {
|
||||
response, err := c.getResponse(iter.Val())
|
||||
iter := c.client.Scan(ctx, 0, prefixKey("*"), 0).Iterator()
|
||||
if err := iter.Err(); err != nil {
|
||||
c.l.Error("GetRedisCache ", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for iter.Next(ctx) {
|
||||
response, err := c.getResponse(ctx, iter.Val())
|
||||
if err == nil {
|
||||
if response != nil {
|
||||
c.CacheChannel <- response
|
||||
if !util.CtxSend(ctx, c.CacheChannel, response) {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
c.l.Error("GetRedisCache ", err)
|
||||
|
@ -172,10 +179,10 @@ func (c *Client) GetRedisCache() {
|
|||
}
|
||||
|
||||
// startup starts a new goroutine for subscription and translation
|
||||
func (c *Client) startup() error {
|
||||
ps := c.client.Subscribe(c.ctx, SyncChannelName)
|
||||
func (c *Client) startup(ctx context.Context) error {
|
||||
ps := c.client.Subscribe(ctx, SyncChannelName)
|
||||
|
||||
_, err := ps.Receive(c.ctx)
|
||||
_, err := ps.Receive(ctx)
|
||||
if err == nil {
|
||||
go func() {
|
||||
for {
|
||||
|
@ -186,11 +193,16 @@ func (c *Client) startup() error {
|
|||
|
||||
if msg != nil && len(msg.Payload) > 0 {
|
||||
// message is not empty
|
||||
c.processReceivedMessage(msg)
|
||||
c.processReceivedMessage(ctx, msg)
|
||||
}
|
||||
// publish message from buffer
|
||||
case s := <-c.sendBuffer:
|
||||
c.publishMessageFromBuffer(s)
|
||||
c.publishMessageFromBuffer(ctx, s)
|
||||
// context is done
|
||||
case <-ctx.Done():
|
||||
c.client.Close()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -199,7 +211,7 @@ func (c *Client) startup() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
|
||||
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) {
|
||||
origRes := s.Message
|
||||
origRes.Compress = true
|
||||
binRes, pErr := origRes.Pack()
|
||||
|
@ -213,61 +225,61 @@ func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
|
|||
})
|
||||
|
||||
if mErr == nil {
|
||||
c.client.Publish(c.ctx, SyncChannelName, binMsg)
|
||||
c.client.Publish(ctx, SyncChannelName, binMsg)
|
||||
}
|
||||
|
||||
c.client.Set(c.ctx,
|
||||
c.client.Set(ctx,
|
||||
prefixKey(s.Key),
|
||||
binRes,
|
||||
c.getTTL(origRes))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) processReceivedMessage(msg *redis.Message) {
|
||||
func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) {
|
||||
var rm redisMessage
|
||||
|
||||
err := json.Unmarshal([]byte(msg.Payload), &rm)
|
||||
if err == nil {
|
||||
// message was sent from a different blocky instance
|
||||
if !bytes.Equal(rm.Client, c.id) {
|
||||
switch rm.Type {
|
||||
case messageTypeCache:
|
||||
var cm *CacheMessage
|
||||
if err := json.Unmarshal([]byte(msg.Payload), &rm); err != nil {
|
||||
c.l.Error("Processing error: ", err)
|
||||
|
||||
cm, err = convertMessage(&rm, 0)
|
||||
if err == nil {
|
||||
c.CacheChannel <- cm
|
||||
}
|
||||
case messageTypeEnable:
|
||||
err = c.processEnabledMessage(&rm)
|
||||
default:
|
||||
c.l.Warn("Unknown message type: ", rm.Type)
|
||||
return
|
||||
}
|
||||
|
||||
// message was sent from a different blocky instance
|
||||
if !bytes.Equal(rm.Client, c.id) {
|
||||
switch rm.Type {
|
||||
case messageTypeCache:
|
||||
var cm *CacheMessage
|
||||
|
||||
cm, err := convertMessage(&rm, 0)
|
||||
if err != nil {
|
||||
c.l.Error("Processing CacheMessage error: ", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
util.CtxSend(ctx, c.CacheChannel, cm)
|
||||
case messageTypeEnable:
|
||||
var msg EnabledMessage
|
||||
|
||||
if err := json.Unmarshal(rm.Message, &msg); err != nil {
|
||||
c.l.Error("Processing EnabledMessage error: ", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
util.CtxSend(ctx, c.EnabledChannel, &msg)
|
||||
default:
|
||||
c.l.Warn("Unknown message type: ", rm.Type)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.l.Error("Processing error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) processEnabledMessage(redisMsg *redisMessage) error {
|
||||
var msg EnabledMessage
|
||||
|
||||
err := json.Unmarshal(redisMsg.Message, &msg)
|
||||
if err == nil {
|
||||
c.EnabledChannel <- &msg
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// getResponse returns model.Response for a key
|
||||
func (c *Client) getResponse(key string) (*CacheMessage, error) {
|
||||
resp, err := c.client.Get(c.ctx, key).Result()
|
||||
func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) {
|
||||
resp, err := c.client.Get(ctx, key).Result()
|
||||
if err == nil {
|
||||
var ttl time.Duration
|
||||
ttl, err = c.client.TTL(c.ctx, key).Result()
|
||||
ttl, err = c.client.TTL(ctx, key).Result()
|
||||
|
||||
if err == nil {
|
||||
var result *CacheMessage
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
|
@ -18,76 +19,75 @@ const (
|
|||
exampleComKey = CacheStorePrefix + "example.com"
|
||||
)
|
||||
|
||||
var (
|
||||
redisServer *miniredis.Miniredis
|
||||
redisClient *Client
|
||||
redisConfig *config.RedisConfig
|
||||
err error
|
||||
)
|
||||
|
||||
var _ = Describe("Redis client", func() {
|
||||
var (
|
||||
redisConfig *config.Redis
|
||||
|
||||
redisClient *Client
|
||||
|
||||
err error
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
redisServer, err = miniredis.Run()
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
DeferCleanup(redisServer.Close)
|
||||
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
rcfg.Address = redisServer.Addr()
|
||||
var rcfg config.Redis
|
||||
Expect(defaults.Set(&rcfg)).Should(Succeed())
|
||||
redisConfig = &rcfg
|
||||
redisClient, err = New(redisConfig)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(redisClient).ShouldNot(BeNil())
|
||||
})
|
||||
|
||||
Describe("Client creation", func() {
|
||||
When("redis configuration has no address", func() {
|
||||
It("should return nil without error", func() {
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Expect(New(&rcfg)).Should(BeNil())
|
||||
It("should return nil without error", func(ctx context.Context) {
|
||||
Expect(New(ctx, redisConfig)).Should(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
When("redis configuration has invalid address", func() {
|
||||
It("should fail with error", func() {
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
rcfg.Address = "127.0.0.1:0"
|
||||
|
||||
_, err = New(&rcfg)
|
||||
BeforeEach(func() {
|
||||
redisConfig.Address = "127.0.0.1:0"
|
||||
})
|
||||
|
||||
It("should fail with error", func(ctx context.Context) {
|
||||
_, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
When("sentinel is enabled without servers", func() {
|
||||
BeforeEach(func() {
|
||||
redisConfig.Address = "test"
|
||||
redisConfig.SentinelAddresses = []string{"127.0.0.1:0"}
|
||||
})
|
||||
|
||||
It("should fail with error", func(ctx context.Context) {
|
||||
_, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
When("redis configuration has invalid password", func() {
|
||||
It("should fail with error", func() {
|
||||
var rcfg config.RedisConfig
|
||||
err = defaults.Set(&rcfg)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
rcfg.Address = redisServer.Addr()
|
||||
rcfg.Password = "wrong"
|
||||
|
||||
_, err = New(&rcfg)
|
||||
BeforeEach(func() {
|
||||
setupRedisServer(redisConfig)
|
||||
redisConfig.Password = "wrong"
|
||||
})
|
||||
|
||||
It("should fail with error", func(ctx context.Context) {
|
||||
_, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Publish message", func() {
|
||||
var redisServer *miniredis.Miniredis
|
||||
BeforeEach(func() {
|
||||
redisServer = setupRedisServer(redisConfig)
|
||||
})
|
||||
|
||||
When("Redis client publishes 'cache' message", func() {
|
||||
It("One new entry with TTL > 0 should be persisted in the database", func() {
|
||||
It("One new entry with TTL > 0 should be persisted in the database", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
By("Database is empty", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
|
@ -112,7 +112,10 @@ var _ = Describe("Redis client", func() {
|
|||
})
|
||||
})
|
||||
|
||||
It("One new entry with default TTL should be persisted in the database", func() {
|
||||
It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
By("Database is empty", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
|
@ -138,20 +141,30 @@ var _ = Describe("Redis client", func() {
|
|||
})
|
||||
})
|
||||
When("Redis client publishes 'enabled' message", func() {
|
||||
It("should propagate the message over redis", func() {
|
||||
redisClient.PublishEnabled(&EnabledMessage{
|
||||
It("should propagate the message over redis", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisClient.PublishEnabled(ctx, &EnabledMessage{
|
||||
State: true,
|
||||
})
|
||||
Eventually(func() map[string]int {
|
||||
return redisServer.PubSubNumSub(SyncChannelName)
|
||||
}).Should(HaveLen(1))
|
||||
})
|
||||
}, SpecTimeout(time.Second*6))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Receive message", func() {
|
||||
var redisServer *miniredis.Miniredis
|
||||
BeforeEach(func() {
|
||||
redisServer = setupRedisServer(redisConfig)
|
||||
})
|
||||
When("'enabled' message is received", func() {
|
||||
It("should propagate the message over the channel", func() {
|
||||
It("should propagate the message over the channel", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var binState []byte
|
||||
binState, err = json.Marshal(EnabledMessage{State: true})
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -179,7 +192,10 @@ var _ = Describe("Redis client", func() {
|
|||
})
|
||||
})
|
||||
When("'cache' message is received", func() {
|
||||
It("should propagate the message over the channel", func() {
|
||||
It("should propagate the message over the channel", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -209,10 +225,13 @@ var _ = Describe("Redis client", func() {
|
|||
Eventually(func() chan *CacheMessage {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenE + 1))
|
||||
})
|
||||
}, SpecTimeout(time.Second*6))
|
||||
})
|
||||
When("wrong data is received", func() {
|
||||
It("should not propagate the message over the channel if data is wrong", func() {
|
||||
It("should not propagate the message over the channel if data is wrong", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var id []byte
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -239,8 +258,11 @@ var _ = Describe("Redis client", func() {
|
|||
Eventually(func() chan *CacheMessage {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenC))
|
||||
})
|
||||
It("should not propagate the message over the channel if type is wrong", func() {
|
||||
}, SpecTimeout(time.Second*6))
|
||||
It("should not propagate the message over the channel if type is wrong", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var id []byte
|
||||
id, err = uuid.New().MarshalBinary()
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -269,13 +291,20 @@ var _ = Describe("Redis client", func() {
|
|||
Eventually(func() chan *CacheMessage {
|
||||
return redisClient.CacheChannel
|
||||
}).Should(HaveLen(lenC))
|
||||
})
|
||||
}, SpecTimeout(time.Second*6))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Read the redis cache and publish it to the channel", func() {
|
||||
var redisServer *miniredis.Miniredis
|
||||
BeforeEach(func() {
|
||||
redisServer = setupRedisServer(redisConfig)
|
||||
})
|
||||
When("GetRedisCache is called with valid database entries", func() {
|
||||
It("Should read data from Redis and propagate it via cache channel", func() {
|
||||
It("Should read data from Redis and propagate it via cache channel", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
By("Database is empty", func() {
|
||||
Eventually(func() []string {
|
||||
return redisServer.DB(redisConfig.Database).Keys()
|
||||
|
@ -299,18 +328,30 @@ var _ = Describe("Redis client", func() {
|
|||
})
|
||||
|
||||
By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() {
|
||||
redisClient.GetRedisCache()
|
||||
redisClient.GetRedisCache(ctx)
|
||||
|
||||
Eventually(redisClient.CacheChannel).Should(HaveLen(1))
|
||||
})
|
||||
})
|
||||
}, SpecTimeout(time.Second*4))
|
||||
})
|
||||
When("GetRedisCache is called and database contains not valid entry", func() {
|
||||
It("Should do nothing (only log error)", func() {
|
||||
It("Should do nothing (only log error)", func(ctx context.Context) {
|
||||
redisClient, err = New(ctx, redisConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed())
|
||||
redisClient.GetRedisCache()
|
||||
redisClient.GetRedisCache(ctx)
|
||||
Consistently(redisClient.CacheChannel).Should(BeEmpty())
|
||||
})
|
||||
}, SpecTimeout(time.Second*2))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func setupRedisServer(cfg *config.Redis) *miniredis.Miniredis {
|
||||
redisServer, err := miniredis.Run()
|
||||
Expect(err).Should(Succeed())
|
||||
DeferCleanup(redisServer.Close)
|
||||
cfg.Address = redisServer.Addr()
|
||||
|
||||
return redisServer
|
||||
}
|
||||
|
|
|
@ -183,7 +183,7 @@ func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
|
|||
if em.State {
|
||||
r.internalEnableBlocking()
|
||||
} else {
|
||||
err := r.internalDisableBlocking(em.Duration, em.Groups)
|
||||
err := r.internalDisableBlocking(ctx, em.Duration, em.Groups)
|
||||
if err != nil {
|
||||
r.log().Warn("Blocking couldn't be disabled:", err)
|
||||
}
|
||||
|
@ -216,11 +216,11 @@ func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
|
|||
}
|
||||
|
||||
// EnableBlocking enables the blocking against the blacklists
|
||||
func (r *BlockingResolver) EnableBlocking() {
|
||||
func (r *BlockingResolver) EnableBlocking(ctx context.Context) {
|
||||
r.internalEnableBlocking()
|
||||
|
||||
if r.redisClient != nil {
|
||||
r.redisClient.PublishEnabled(&redis.EnabledMessage{State: true})
|
||||
r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{State: true})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -236,10 +236,10 @@ func (r *BlockingResolver) internalEnableBlocking() {
|
|||
}
|
||||
|
||||
// DisableBlocking deactivates the blocking for a particular duration (or forever if 0).
|
||||
func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups []string) error {
|
||||
err := r.internalDisableBlocking(duration, disableGroups)
|
||||
func (r *BlockingResolver) DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error {
|
||||
err := r.internalDisableBlocking(ctx, duration, disableGroups)
|
||||
if err == nil && r.redisClient != nil {
|
||||
r.redisClient.PublishEnabled(&redis.EnabledMessage{
|
||||
r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{
|
||||
State: false,
|
||||
Duration: duration,
|
||||
Groups: disableGroups,
|
||||
|
@ -249,7 +249,9 @@ func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups
|
|||
return err
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disableGroups []string) error {
|
||||
func (r *BlockingResolver) internalDisableBlocking(ctx context.Context, duration time.Duration,
|
||||
disableGroups []string,
|
||||
) error {
|
||||
s := r.status
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
@ -280,7 +282,7 @@ func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disab
|
|||
log.Log().Infof("disable blocking for %s for group(s) '%s'", duration,
|
||||
log.EscapeInput(strings.Join(s.disabledGroups, "; ")))
|
||||
s.enableTimer = time.AfterFunc(duration, func() {
|
||||
r.EnableBlocking()
|
||||
r.EnableBlocking(ctx)
|
||||
log.Log().Info("blocking enabled again")
|
||||
})
|
||||
}
|
||||
|
|
|
@ -842,7 +842,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("Calling Rest API to deactivate all groups", func() {
|
||||
err := sut.DisableBlocking(0, []string{})
|
||||
err := sut.DisableBlocking(context.TODO(), 0, []string{})
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -875,7 +875,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("Calling Rest API to deactivate only defaultGroup", func() {
|
||||
err := sut.DisableBlocking(0, []string{"defaultGroup"})
|
||||
err := sut.DisableBlocking(context.TODO(), 0, []string{"defaultGroup"})
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -935,7 +935,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
enabled <- state
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
err = sut.DisableBlocking(500*time.Millisecond, []string{})
|
||||
err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{})
|
||||
Expect(err).Should(Succeed())
|
||||
Eventually(enabled, "1s").Should(Receive(BeFalse()))
|
||||
})
|
||||
|
@ -1025,7 +1025,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
enabled <- false
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
err = sut.DisableBlocking(500*time.Millisecond, []string{"group1"})
|
||||
err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"group1"})
|
||||
Expect(err).Should(Succeed())
|
||||
Eventually(enabled, "1s").Should(Receive(BeFalse()))
|
||||
})
|
||||
|
@ -1086,7 +1086,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Disable blocking is called with wrong group name", func() {
|
||||
It("should fail", func() {
|
||||
err := sut.DisableBlocking(500*time.Millisecond, []string{"unknownGroupName"})
|
||||
err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"unknownGroupName"})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
@ -1094,7 +1094,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
When("Blocking status is called", func() {
|
||||
It("should return correct status", func() {
|
||||
By("enable blocking via API", func() {
|
||||
sut.EnableBlocking()
|
||||
sut.EnableBlocking(context.TODO())
|
||||
})
|
||||
|
||||
By("Query blocking status via API should return 'enabled'", func() {
|
||||
|
@ -1103,7 +1103,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("disable blocking via API", func() {
|
||||
err := sut.DisableBlocking(500*time.Millisecond, []string{})
|
||||
err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{})
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -1149,12 +1149,12 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var rcfg config.RedisConfig
|
||||
var rcfg config.Redis
|
||||
err = defaults.Set(&rcfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
rcfg.Address = redisServer.Addr()
|
||||
redisClient, err = redis.New(&rcfg)
|
||||
redisClient, err = redis.New(context.TODO(), &rcfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(redisClient).ShouldNot(BeNil())
|
||||
|
@ -1171,7 +1171,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("disable", func() {
|
||||
It("should return disable", func() {
|
||||
sut.EnableBlocking()
|
||||
sut.EnableBlocking(context.TODO())
|
||||
|
||||
redisMockMsg := &redis.EnabledMessage{
|
||||
State: false,
|
||||
|
@ -1185,7 +1185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("disable", func() {
|
||||
It("should return disable", func() {
|
||||
sut.EnableBlocking()
|
||||
sut.EnableBlocking(context.TODO())
|
||||
redisMockMsg := &redis.EnabledMessage{
|
||||
State: false,
|
||||
Groups: []string{"unknown"},
|
||||
|
@ -1199,7 +1199,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("enable", func() {
|
||||
It("should return enable", func() {
|
||||
err = sut.DisableBlocking(time.Hour, []string{})
|
||||
err = sut.DisableBlocking(context.TODO(), time.Hour, []string{})
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
redisMockMsg := &redis.EnabledMessage{
|
||||
|
|
|
@ -60,7 +60,7 @@ func newCachingResolver(ctx context.Context,
|
|||
|
||||
if c.redisClient != nil {
|
||||
go c.redisSubscriber(ctx)
|
||||
c.redisClient.GetRedisCache()
|
||||
c.redisClient.GetRedisCache(ctx)
|
||||
}
|
||||
|
||||
return c
|
||||
|
|
|
@ -733,7 +733,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
var (
|
||||
redisServer *miniredis.Miniredis
|
||||
redisClient *redis.Client
|
||||
redisConfig *config.RedisConfig
|
||||
redisConfig *config.Redis
|
||||
err error
|
||||
)
|
||||
BeforeEach(func() {
|
||||
|
@ -741,14 +741,14 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
var rcfg config.RedisConfig
|
||||
var rcfg config.Redis
|
||||
err = defaults.Set(&rcfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
rcfg.Address = redisServer.Addr()
|
||||
redisConfig = &rcfg
|
||||
redisClient, err = redis.New(redisConfig)
|
||||
redisClient, err = redis.New(context.TODO(), redisConfig)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(redisClient).ShouldNot(BeNil())
|
||||
|
|
|
@ -452,14 +452,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
|
||||
Describe("Weighted random on resolver selection", func() {
|
||||
When("4 upstream resolvers are defined", func() {
|
||||
var (
|
||||
mockUpstream1 *MockUDPUpstreamServer
|
||||
mockUpstream2 *MockUDPUpstreamServer
|
||||
)
|
||||
BeforeEach(func() {
|
||||
withError1 := config.Upstream{Host: "wrong1"}
|
||||
withError2 := config.Upstream{Host: "wrong2"}
|
||||
|
||||
mockUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
mockUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream1.Close)
|
||||
|
||||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
mockUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream2.Close)
|
||||
|
||||
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
|
||||
|
|
|
@ -31,6 +31,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
timeout = 2 * time.Second
|
||||
|
||||
testUpstream1 *MockUDPUpstreamServer
|
||||
testUpstream2 *MockUDPUpstreamServer
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -58,6 +62,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
upstreamsCfg.StartVerify = sutVerify
|
||||
|
||||
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
|
||||
sutConfig.Timeout = config.Duration(timeout)
|
||||
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap)
|
||||
})
|
||||
|
||||
|
@ -143,10 +148,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
When("Both are responding", func() {
|
||||
When("they respond in time", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
||||
|
@ -165,8 +170,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
When("first upstream exceeds upstreamTimeout", func() {
|
||||
BeforeEach(func() {
|
||||
timeout := sut.cfg.Timeout.ToDuration()
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||
time.Sleep(2 * timeout)
|
||||
|
||||
|
@ -176,7 +180,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
||||
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
||||
|
@ -193,9 +197,8 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
})
|
||||
When("all upstreams exceed upsteamTimeout", func() {
|
||||
BeforeEach(func() {
|
||||
timeout := sut.cfg.Timeout.ToDuration()
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
JustBeforeEach(func() {
|
||||
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||
time.Sleep(2 * timeout)
|
||||
|
||||
|
@ -205,7 +208,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
|
||||
time.Sleep(2 * timeout)
|
||||
|
||||
|
@ -225,12 +228,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
When("Only second is working", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := config.Upstream{Host: "wrong"}
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
upstreams = []config.Upstream{testUpstream1, testUpstream2.Start()}
|
||||
upstreams = []config.Upstream{{Host: "wrong"}, testUpstream2.Start()}
|
||||
})
|
||||
It("Should use result from second one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
|
|
|
@ -135,9 +135,12 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
return nil, err
|
||||
}
|
||||
|
||||
redisClient, redisErr := redis.New(&cfg.Redis)
|
||||
if redisErr != nil && cfg.Redis.Required {
|
||||
return nil, redisErr
|
||||
var redisClient *redis.Client
|
||||
if cfg.Redis.IsEnabled() {
|
||||
redisClient, err = redis.New(ctx, &cfg.Redis)
|
||||
if err != nil && cfg.Redis.Required {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient)
|
||||
|
@ -423,6 +426,11 @@ func (s *Server) registerDNSHandlers() {
|
|||
func (s *Server) printConfiguration() {
|
||||
logger().Info("current configuration:")
|
||||
|
||||
if s.cfg.Redis.IsEnabled() {
|
||||
logger().Info("Redis:")
|
||||
log.WithIndent(logger(), " ", s.cfg.Redis.LogConfig)
|
||||
}
|
||||
|
||||
resolver.ForEach(s.queryResolver, func(res resolver.Resolver) {
|
||||
resolver.LogResolverConfig(res, logger())
|
||||
})
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
package util
|
||||
|
||||
import "context"
|
||||
|
||||
// CtxSend sends a value to a channel while the context isn't done.
|
||||
// If the message is sent, it returns true.
|
||||
// If the context is done or the channel is closed, it returns false.
|
||||
func CtxSend[T any](ctx context.Context, ch chan T, val T) (ok bool) {
|
||||
if ctx == nil || ch == nil || ctx.Err() != nil {
|
||||
ok = false
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
ok = false
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ok = false
|
||||
case ch <- val:
|
||||
ok = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
testMessage = 1
|
||||
)
|
||||
|
||||
var _ = Describe("Context utils", func() {
|
||||
Describe("CtxSend", func() {
|
||||
var ch chan int
|
||||
BeforeEach(func() {
|
||||
ch = make(chan int, 1)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
ch = nil
|
||||
})
|
||||
|
||||
When("channel is not closed", func() {
|
||||
It("should send value to channel", func(ctx context.Context) {
|
||||
go startReader(ctx, ch)
|
||||
Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("channel is closed", func() {
|
||||
It("should return false", func(ctx context.Context) {
|
||||
go startReader(ctx, ch)
|
||||
close(ch)
|
||||
Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("channel is nil", func() {
|
||||
It("should return false", func(ctx context.Context) {
|
||||
Expect(CtxSend(ctx, nil, testMessage)).Should(BeFalse())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("channel is full", func() {
|
||||
It("should wait", func(ctx context.Context) {
|
||||
ch <- testMessage
|
||||
go func(ctx context.Context, ch chan int) {
|
||||
timer := time.NewTimer(time.Millisecond * 200)
|
||||
select {
|
||||
case <-timer.C:
|
||||
startReader(ctx, ch)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}(ctx, ch)
|
||||
Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("context is done", func() {
|
||||
It("should return false", func(ctx context.Context) {
|
||||
go startReader(ctx, ch)
|
||||
cCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
<-cCtx.Done()
|
||||
Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("context is terminated", func() {
|
||||
It("should return false", func(ctx context.Context) {
|
||||
ch <- testMessage
|
||||
cCtx, cancel := context.WithCancel(ctx)
|
||||
go func(ctx context.Context, cancel context.CancelFunc) {
|
||||
timer := time.NewTimer(time.Millisecond * 200)
|
||||
select {
|
||||
case <-timer.C:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}(ctx, cancel)
|
||||
Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("context is nil", func() {
|
||||
It("should return false", func(ctx context.Context) {
|
||||
Expect(CtxSend(nil, ch, testMessage)).Should(BeFalse())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("context and channel are nil", func() {
|
||||
It("should return false", func(ctx context.Context) {
|
||||
Expect(CtxSend(nil, nil, testMessage)).Should(BeFalse())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
|
||||
When("context is done and channel is closed", func() {
|
||||
It("should return false", func(ctx context.Context) {
|
||||
go startReader(ctx, ch)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
close(ch)
|
||||
Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse())
|
||||
}, SpecTimeout(time.Second))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func startReader(ctx context.Context, ch <-chan int) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case i, ok := <-ch:
|
||||
if ok {
|
||||
Expect(i).Should(Equal(testMessage))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue