mirror of https://github.com/0xERR0R/blocky.git
Refactoring Redis (#1271)
* RedisConfig -> Redis * moved redis config to seperate file * bugfix in config test during parallel processing * implement config.Configurable in Redis config * use Context in GetRedisCache * use Context in New * caching resolver test fix * use Context in PublishEnabled * use Context in getResponse * remove ctx field * bugfix in api interface test * propperly close channels * set ruler for go files from 80 to 111 * line break because function length is to long * only execute redis.New if it is enabled in config * stabilized flaky tests * Update config/redis.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * Update config/redis_test.go Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com> * fix ruler * redis test refactoring * vscode setting cleanup * removed else if chain * Update redis_test.go * context race fix * test fail on missing seintinel servers * cleanup context usage * cleanup2 * context fixes * added context util * disabled nil context rule for tests * copy paste error ctxSend -> CtxSend * use util.CtxSend * fixed comment * fixed flaky test * failsafe and tests --------- Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
This commit is contained in:
parent
15bd383460
commit
fda2dbe9df
|
@ -31,19 +31,13 @@
|
||||||
"GitHub.vscode-github-actions"
|
"GitHub.vscode-github-actions"
|
||||||
],
|
],
|
||||||
"settings": {
|
"settings": {
|
||||||
"go.lintFlags": ["--config=${containerWorkspaceFolder}/.golangci.yml"],
|
"go.lintFlags": [
|
||||||
|
"--config=${containerWorkspaceFolder}/.golangci.yml",
|
||||||
|
"--fast"
|
||||||
|
],
|
||||||
"go.alternateTools": {
|
"go.alternateTools": {
|
||||||
"go-langserver": "gopls"
|
"go-langserver": "gopls"
|
||||||
},
|
},
|
||||||
"[go]": {
|
|
||||||
"editor.defaultFormatter": "golang.go"
|
|
||||||
},
|
|
||||||
"[json][jsonc][github-actions-workflow]": {
|
|
||||||
"editor.defaultFormatter": "esbenp.prettier-vscode"
|
|
||||||
},
|
|
||||||
"[markdown]": {
|
|
||||||
"editor.defaultFormatter": "yzhang.markdown-all-in-one"
|
|
||||||
},
|
|
||||||
"markiscodecoverage.searchCriteria": "**/*.lcov",
|
"markiscodecoverage.searchCriteria": "**/*.lcov",
|
||||||
"runItOn": {
|
"runItOn": {
|
||||||
"commands": [
|
"commands": [
|
||||||
|
@ -52,6 +46,15 @@
|
||||||
"cmd": "${workspaceRoot}/.devcontainer/scripts/runItOnGo.sh ${fileDirname} ${workspaceRoot}"
|
"cmd": "${workspaceRoot}/.devcontainer/scripts/runItOnGo.sh ${fileDirname} ${workspaceRoot}"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"[go]": {
|
||||||
|
"editor.defaultFormatter": "golang.go"
|
||||||
|
},
|
||||||
|
"[json][jsonc][github-actions-workflow]": {
|
||||||
|
"editor.defaultFormatter": "esbenp.prettier-vscode"
|
||||||
|
},
|
||||||
|
"[markdown]": {
|
||||||
|
"editor.defaultFormatter": "yzhang.markdown-all-in-one"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,3 +98,7 @@ issues:
|
||||||
- gochecknoinits
|
- gochecknoinits
|
||||||
- gochecknoglobals
|
- gochecknoglobals
|
||||||
- gosec
|
- gosec
|
||||||
|
- path: _test\.go
|
||||||
|
linters:
|
||||||
|
- staticcheck
|
||||||
|
text: "SA1012:"
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
"source.organizeImports": true,
|
"source.organizeImports": true,
|
||||||
"source.fixAll": true
|
"source.fixAll": true
|
||||||
},
|
},
|
||||||
|
"editor.rulers": [120],
|
||||||
"go.showWelcome": false,
|
"go.showWelcome": false,
|
||||||
"go.survey.prompt": false,
|
"go.survey.prompt": false,
|
||||||
"go.useLanguageServer": true,
|
"go.useLanguageServer": true,
|
||||||
|
|
|
@ -29,8 +29,8 @@ type BlockingStatus struct {
|
||||||
|
|
||||||
// BlockingControl interface to control the blocking status
|
// BlockingControl interface to control the blocking status
|
||||||
type BlockingControl interface {
|
type BlockingControl interface {
|
||||||
EnableBlocking()
|
EnableBlocking(ctx context.Context)
|
||||||
DisableBlocking(duration time.Duration, disableGroups []string) error
|
DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error
|
||||||
BlockingStatus() BlockingStatus
|
BlockingStatus() BlockingStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ func NewOpenAPIInterfaceImpl(control BlockingControl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
|
func (i *OpenAPIInterfaceImpl) DisableBlocking(ctx context.Context,
|
||||||
request DisableBlockingRequestObject,
|
request DisableBlockingRequestObject,
|
||||||
) (DisableBlockingResponseObject, error) {
|
) (DisableBlockingResponseObject, error) {
|
||||||
var (
|
var (
|
||||||
|
@ -91,7 +91,7 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
|
||||||
groups = strings.Split(*request.Params.Groups, ",")
|
groups = strings.Split(*request.Params.Groups, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.control.DisableBlocking(duration, groups)
|
err = i.control.DisableBlocking(ctx, duration, groups)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DisableBlocking400TextResponse(log.EscapeInput(err.Error())), nil
|
return DisableBlocking400TextResponse(log.EscapeInput(err.Error())), nil
|
||||||
|
@ -100,9 +100,9 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context,
|
||||||
return DisableBlocking200Response{}, nil
|
return DisableBlocking200Response{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *OpenAPIInterfaceImpl) EnableBlocking(_ context.Context, _ EnableBlockingRequestObject,
|
func (i *OpenAPIInterfaceImpl) EnableBlocking(ctx context.Context, _ EnableBlockingRequestObject,
|
||||||
) (EnableBlockingResponseObject, error) {
|
) (EnableBlockingResponseObject, error) {
|
||||||
i.control.EnableBlocking()
|
i.control.EnableBlocking(ctx)
|
||||||
|
|
||||||
return EnableBlocking200Response{}, nil
|
return EnableBlocking200Response{}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,11 +37,11 @@ func (m *ListRefreshMock) RefreshLists() error {
|
||||||
return args.Error(0)
|
return args.Error(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *BlockingControlMock) EnableBlocking() {
|
func (m *BlockingControlMock) EnableBlocking(_ context.Context) {
|
||||||
_ = m.Called()
|
_ = m.Called()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *BlockingControlMock) DisableBlocking(t time.Duration, g []string) error {
|
func (m *BlockingControlMock) DisableBlocking(_ context.Context, t time.Duration, g []string) error {
|
||||||
args := m.Called(t, g)
|
args := m.Called(t, g)
|
||||||
|
|
||||||
return args.Error(0)
|
return args.Error(0)
|
||||||
|
|
|
@ -219,7 +219,7 @@ type Config struct {
|
||||||
Caching CachingConfig `yaml:"caching"`
|
Caching CachingConfig `yaml:"caching"`
|
||||||
QueryLog QueryLogConfig `yaml:"queryLog"`
|
QueryLog QueryLogConfig `yaml:"queryLog"`
|
||||||
Prometheus MetricsConfig `yaml:"prometheus"`
|
Prometheus MetricsConfig `yaml:"prometheus"`
|
||||||
Redis RedisConfig `yaml:"redis"`
|
Redis Redis `yaml:"redis"`
|
||||||
Log log.Config `yaml:"log"`
|
Log log.Config `yaml:"log"`
|
||||||
Ports PortsConfig `yaml:"ports"`
|
Ports PortsConfig `yaml:"ports"`
|
||||||
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
|
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
|
||||||
|
@ -280,20 +280,6 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// RedisConfig configuration for the redis connection
|
|
||||||
type RedisConfig struct {
|
|
||||||
Address string `yaml:"address"`
|
|
||||||
Username string `yaml:"username" default:""`
|
|
||||||
Password string `yaml:"password" default:""`
|
|
||||||
Database int `yaml:"database" default:"0"`
|
|
||||||
Required bool `yaml:"required" default:"false"`
|
|
||||||
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
|
|
||||||
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
|
|
||||||
SentinelUsername string `yaml:"sentinelUsername" default:""`
|
|
||||||
SentinelPassword string `yaml:"sentinelPassword" default:""`
|
|
||||||
SentinelAddresses []string `yaml:"sentinelAddresses"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
FQDNOnly = toEnable
|
FQDNOnly = toEnable
|
||||||
EDE = toEnable
|
EDE = toEnable
|
||||||
|
|
|
@ -169,7 +169,7 @@ var _ = Describe("Config", func() {
|
||||||
err = writeConfigDir(tmpDir)
|
err = writeConfigDir(tmpDir)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
_, err := LoadConfig(tmpDir.Path, true)
|
c, err = LoadConfig(tmpDir.Path, true)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
defaultTestFileConfig(c)
|
defaultTestFileConfig(c)
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Redis configuration for the redis connection
|
||||||
|
type Redis struct {
|
||||||
|
Address string `yaml:"address"`
|
||||||
|
Username string `yaml:"username" default:""`
|
||||||
|
Password string `yaml:"password" default:""`
|
||||||
|
Database int `yaml:"database" default:"0"`
|
||||||
|
Required bool `yaml:"required" default:"false"`
|
||||||
|
ConnectionAttempts int `yaml:"connectionAttempts" default:"3"`
|
||||||
|
ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
|
||||||
|
SentinelUsername string `yaml:"sentinelUsername" default:""`
|
||||||
|
SentinelPassword string `yaml:"sentinelPassword" default:""`
|
||||||
|
SentinelAddresses []string `yaml:"sentinelAddresses"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled implements `config.Configurable`
|
||||||
|
func (c *Redis) IsEnabled() bool {
|
||||||
|
return c.Address != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogConfig implements `config.Configurable`
|
||||||
|
func (c *Redis) LogConfig(logger *logrus.Entry) {
|
||||||
|
if len(c.SentinelAddresses) == 0 {
|
||||||
|
logger.Info("address: ", c.Address)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("username: ", c.Username)
|
||||||
|
logger.Info("password: ", obfuscatePassword(c.Password))
|
||||||
|
logger.Info("database: ", c.Database)
|
||||||
|
logger.Info("required: ", c.Required)
|
||||||
|
logger.Info("connectionAttempts: ", c.ConnectionAttempts)
|
||||||
|
logger.Info("connectionCooldown: ", c.ConnectionCooldown)
|
||||||
|
|
||||||
|
if len(c.SentinelAddresses) > 0 {
|
||||||
|
logger.Info("sentinel:")
|
||||||
|
logger.Info(" master: ", c.Address)
|
||||||
|
logger.Info(" username: ", c.SentinelUsername)
|
||||||
|
logger.Info(" password: ", obfuscatePassword(c.SentinelPassword))
|
||||||
|
logger.Info(" addresses:")
|
||||||
|
|
||||||
|
for _, addr := range c.SentinelAddresses {
|
||||||
|
logger.Info(" - ", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscatePassword replaces all characters of a password except the first and last with *
|
||||||
|
func obfuscatePassword(pass string) string {
|
||||||
|
return strings.Repeat("*", len(pass))
|
||||||
|
}
|
|
@ -0,0 +1,104 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/0xERR0R/blocky/log"
|
||||||
|
"github.com/creasty/defaults"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Redis", func() {
|
||||||
|
var (
|
||||||
|
c Redis
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
suiteBeforeEach()
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
err = defaults.Set(&c)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("IsEnabled", func() {
|
||||||
|
When("all fields are default", func() {
|
||||||
|
It("should be disabled", func() {
|
||||||
|
Expect(c.IsEnabled()).Should(BeFalse())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
When("Address is set", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
c.Address = "localhost:6379"
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should be enabled", func() {
|
||||||
|
Expect(c.IsEnabled()).Should(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("LogConfig", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
logger, hook = log.NewMockEntry()
|
||||||
|
})
|
||||||
|
|
||||||
|
When("all fields are default", func() {
|
||||||
|
It("should log default values", func() {
|
||||||
|
c.LogConfig(logger)
|
||||||
|
|
||||||
|
Expect(hook.Messages).Should(
|
||||||
|
SatisfyAll(ContainElement(ContainSubstring("address: ")),
|
||||||
|
ContainElement(ContainSubstring("username: ")),
|
||||||
|
ContainElement(ContainSubstring("password: ")),
|
||||||
|
ContainElement(ContainSubstring("database: ")),
|
||||||
|
ContainElement(ContainSubstring("required: ")),
|
||||||
|
ContainElement(ContainSubstring("connectionAttempts: ")),
|
||||||
|
ContainElement(ContainSubstring("connectionCooldown: "))))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
When("Address is set", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
c.Address = "localhost:6379"
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should log address", func() {
|
||||||
|
c.LogConfig(logger)
|
||||||
|
|
||||||
|
Expect(hook.Messages).Should(ContainElement(ContainSubstring("address: localhost:6379")))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
When("SentinelAddresses is set", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
c.SentinelAddresses = []string{"localhost:26379", "localhost:26380"}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should log sentinel addresses", func() {
|
||||||
|
c.LogConfig(logger)
|
||||||
|
|
||||||
|
Expect(hook.Messages).Should(
|
||||||
|
SatisfyAll(
|
||||||
|
ContainElement(ContainSubstring("sentinel:")),
|
||||||
|
ContainElement(ContainSubstring(" addresses:")),
|
||||||
|
ContainElement(ContainSubstring(" - localhost:26379")),
|
||||||
|
ContainElement(ContainSubstring(" - localhost:26380"))))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("obfuscatePassword", func() {
|
||||||
|
When("password is empty", func() {
|
||||||
|
It("should return empty string", func() {
|
||||||
|
Expect(obfuscatePassword("")).Should(Equal(""))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
When("password is not empty", func() {
|
||||||
|
It("should return obfuscated password", func() {
|
||||||
|
Expect(obfuscatePassword("test123")).Should(Equal("*******"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
128
redis/redis.go
128
redis/redis.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
"github.com/0xERR0R/blocky/model"
|
"github.com/0xERR0R/blocky/model"
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
@ -56,10 +57,9 @@ type EnabledMessage struct {
|
||||||
|
|
||||||
// Client for redis communication
|
// Client for redis communication
|
||||||
type Client struct {
|
type Client struct {
|
||||||
config *config.RedisConfig
|
config *config.Redis
|
||||||
client *redis.Client
|
client *redis.Client
|
||||||
l *logrus.Entry
|
l *logrus.Entry
|
||||||
ctx context.Context
|
|
||||||
id []byte
|
id []byte
|
||||||
sendBuffer chan *bufferMessage
|
sendBuffer chan *bufferMessage
|
||||||
CacheChannel chan *CacheMessage
|
CacheChannel chan *CacheMessage
|
||||||
|
@ -67,15 +67,15 @@ type Client struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new redis client
|
// New creates a new redis client
|
||||||
func New(cfg *config.RedisConfig) (*Client, error) {
|
func New(ctx context.Context, cfg *config.Redis) (*Client, error) {
|
||||||
// disable redis if no address is provided
|
// disable redis if no address is provided
|
||||||
if cfg == nil || len(cfg.Address) == 0 {
|
if cfg == nil || len(cfg.Address) == 0 {
|
||||||
return nil, nil //nolint:nilnil
|
return nil, nil //nolint:nilnil
|
||||||
}
|
}
|
||||||
|
|
||||||
var rdb *redis.Client
|
var baseClient *redis.Client
|
||||||
if len(cfg.SentinelAddresses) > 0 {
|
if len(cfg.SentinelAddresses) > 0 {
|
||||||
rdb = redis.NewFailoverClient(&redis.FailoverOptions{
|
baseClient = redis.NewFailoverClient(&redis.FailoverOptions{
|
||||||
MasterName: cfg.Address,
|
MasterName: cfg.Address,
|
||||||
SentinelUsername: cfg.Username,
|
SentinelUsername: cfg.Username,
|
||||||
SentinelPassword: cfg.SentinelPassword,
|
SentinelPassword: cfg.SentinelPassword,
|
||||||
|
@ -87,7 +87,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
||||||
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
rdb = redis.NewClient(&redis.Options{
|
baseClient = redis.NewClient(&redis.Options{
|
||||||
Addr: cfg.Address,
|
Addr: cfg.Address,
|
||||||
Username: cfg.Username,
|
Username: cfg.Username,
|
||||||
Password: cfg.Password,
|
Password: cfg.Password,
|
||||||
|
@ -97,7 +97,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
rdb := baseClient.WithContext(ctx)
|
||||||
|
|
||||||
_, err := rdb.Ping(ctx).Result()
|
_, err := rdb.Ping(ctx).Result()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -110,7 +110,6 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
||||||
config: cfg,
|
config: cfg,
|
||||||
client: rdb,
|
client: rdb,
|
||||||
l: log.PrefixedLog("redis"),
|
l: log.PrefixedLog("redis"),
|
||||||
ctx: ctx,
|
|
||||||
id: id,
|
id: id,
|
||||||
sendBuffer: make(chan *bufferMessage, chanCap),
|
sendBuffer: make(chan *bufferMessage, chanCap),
|
||||||
CacheChannel: make(chan *CacheMessage, chanCap),
|
CacheChannel: make(chan *CacheMessage, chanCap),
|
||||||
|
@ -118,7 +117,7 @@ func New(cfg *config.RedisConfig) (*Client, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// start channel handling go routine
|
// start channel handling go routine
|
||||||
err = res.startup()
|
err = res.startup(ctx)
|
||||||
|
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
@ -137,7 +136,7 @@ func (c *Client) PublishCache(key string, message *dns.Msg) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) PublishEnabled(state *EnabledMessage) {
|
func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) {
|
||||||
binState, sErr := json.Marshal(state)
|
binState, sErr := json.Marshal(state)
|
||||||
if sErr == nil {
|
if sErr == nil {
|
||||||
binMsg, mErr := json.Marshal(redisMessage{
|
binMsg, mErr := json.Marshal(redisMessage{
|
||||||
|
@ -147,22 +146,30 @@ func (c *Client) PublishEnabled(state *EnabledMessage) {
|
||||||
})
|
})
|
||||||
|
|
||||||
if mErr == nil {
|
if mErr == nil {
|
||||||
c.client.Publish(c.ctx, SyncChannelName, binMsg)
|
c.client.Publish(ctx, SyncChannelName, binMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRedisCache reads the redis cache and publish it to the channel
|
// GetRedisCache reads the redis cache and publish it to the channel
|
||||||
func (c *Client) GetRedisCache() {
|
func (c *Client) GetRedisCache(ctx context.Context) {
|
||||||
c.l.Debug("GetRedisCache")
|
c.l.Debug("GetRedisCache")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
iter := c.client.Scan(c.ctx, 0, prefixKey("*"), 0).Iterator()
|
iter := c.client.Scan(ctx, 0, prefixKey("*"), 0).Iterator()
|
||||||
for iter.Next(c.ctx) {
|
if err := iter.Err(); err != nil {
|
||||||
response, err := c.getResponse(iter.Val())
|
c.l.Error("GetRedisCache ", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for iter.Next(ctx) {
|
||||||
|
response, err := c.getResponse(ctx, iter.Val())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if response != nil {
|
if response != nil {
|
||||||
c.CacheChannel <- response
|
if !util.CtxSend(ctx, c.CacheChannel, response) {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
c.l.Error("GetRedisCache ", err)
|
c.l.Error("GetRedisCache ", err)
|
||||||
|
@ -172,10 +179,10 @@ func (c *Client) GetRedisCache() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// startup starts a new goroutine for subscription and translation
|
// startup starts a new goroutine for subscription and translation
|
||||||
func (c *Client) startup() error {
|
func (c *Client) startup(ctx context.Context) error {
|
||||||
ps := c.client.Subscribe(c.ctx, SyncChannelName)
|
ps := c.client.Subscribe(ctx, SyncChannelName)
|
||||||
|
|
||||||
_, err := ps.Receive(c.ctx)
|
_, err := ps.Receive(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
@ -186,11 +193,16 @@ func (c *Client) startup() error {
|
||||||
|
|
||||||
if msg != nil && len(msg.Payload) > 0 {
|
if msg != nil && len(msg.Payload) > 0 {
|
||||||
// message is not empty
|
// message is not empty
|
||||||
c.processReceivedMessage(msg)
|
c.processReceivedMessage(ctx, msg)
|
||||||
}
|
}
|
||||||
// publish message from buffer
|
// publish message from buffer
|
||||||
case s := <-c.sendBuffer:
|
case s := <-c.sendBuffer:
|
||||||
c.publishMessageFromBuffer(s)
|
c.publishMessageFromBuffer(ctx, s)
|
||||||
|
// context is done
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.client.Close()
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -199,7 +211,7 @@ func (c *Client) startup() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
|
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) {
|
||||||
origRes := s.Message
|
origRes := s.Message
|
||||||
origRes.Compress = true
|
origRes.Compress = true
|
||||||
binRes, pErr := origRes.Pack()
|
binRes, pErr := origRes.Pack()
|
||||||
|
@ -213,61 +225,61 @@ func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
|
||||||
})
|
})
|
||||||
|
|
||||||
if mErr == nil {
|
if mErr == nil {
|
||||||
c.client.Publish(c.ctx, SyncChannelName, binMsg)
|
c.client.Publish(ctx, SyncChannelName, binMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.client.Set(c.ctx,
|
c.client.Set(ctx,
|
||||||
prefixKey(s.Key),
|
prefixKey(s.Key),
|
||||||
binRes,
|
binRes,
|
||||||
c.getTTL(origRes))
|
c.getTTL(origRes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) processReceivedMessage(msg *redis.Message) {
|
func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) {
|
||||||
var rm redisMessage
|
var rm redisMessage
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(msg.Payload), &rm)
|
if err := json.Unmarshal([]byte(msg.Payload), &rm); err != nil {
|
||||||
if err == nil {
|
c.l.Error("Processing error: ", err)
|
||||||
// message was sent from a different blocky instance
|
|
||||||
if !bytes.Equal(rm.Client, c.id) {
|
|
||||||
switch rm.Type {
|
|
||||||
case messageTypeCache:
|
|
||||||
var cm *CacheMessage
|
|
||||||
|
|
||||||
cm, err = convertMessage(&rm, 0)
|
return
|
||||||
if err == nil {
|
}
|
||||||
c.CacheChannel <- cm
|
|
||||||
}
|
// message was sent from a different blocky instance
|
||||||
case messageTypeEnable:
|
if !bytes.Equal(rm.Client, c.id) {
|
||||||
err = c.processEnabledMessage(&rm)
|
switch rm.Type {
|
||||||
default:
|
case messageTypeCache:
|
||||||
c.l.Warn("Unknown message type: ", rm.Type)
|
var cm *CacheMessage
|
||||||
|
|
||||||
|
cm, err := convertMessage(&rm, 0)
|
||||||
|
if err != nil {
|
||||||
|
c.l.Error("Processing CacheMessage error: ", err)
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
util.CtxSend(ctx, c.CacheChannel, cm)
|
||||||
|
case messageTypeEnable:
|
||||||
|
var msg EnabledMessage
|
||||||
|
|
||||||
|
if err := json.Unmarshal(rm.Message, &msg); err != nil {
|
||||||
|
c.l.Error("Processing EnabledMessage error: ", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.CtxSend(ctx, c.EnabledChannel, &msg)
|
||||||
|
default:
|
||||||
|
c.l.Warn("Unknown message type: ", rm.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
c.l.Error("Processing error: ", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) processEnabledMessage(redisMsg *redisMessage) error {
|
|
||||||
var msg EnabledMessage
|
|
||||||
|
|
||||||
err := json.Unmarshal(redisMsg.Message, &msg)
|
|
||||||
if err == nil {
|
|
||||||
c.EnabledChannel <- &msg
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getResponse returns model.Response for a key
|
// getResponse returns model.Response for a key
|
||||||
func (c *Client) getResponse(key string) (*CacheMessage, error) {
|
func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) {
|
||||||
resp, err := c.client.Get(c.ctx, key).Result()
|
resp, err := c.client.Get(ctx, key).Result()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var ttl time.Duration
|
var ttl time.Duration
|
||||||
ttl, err = c.client.TTL(c.ctx, key).Result()
|
ttl, err = c.client.TTL(ctx, key).Result()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var result *CacheMessage
|
var result *CacheMessage
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package redis
|
package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -18,76 +19,75 @@ const (
|
||||||
exampleComKey = CacheStorePrefix + "example.com"
|
exampleComKey = CacheStorePrefix + "example.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
redisServer *miniredis.Miniredis
|
|
||||||
redisClient *Client
|
|
||||||
redisConfig *config.RedisConfig
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Redis client", func() {
|
var _ = Describe("Redis client", func() {
|
||||||
|
var (
|
||||||
|
redisConfig *config.Redis
|
||||||
|
|
||||||
|
redisClient *Client
|
||||||
|
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
redisServer, err = miniredis.Run()
|
var rcfg config.Redis
|
||||||
|
Expect(defaults.Set(&rcfg)).Should(Succeed())
|
||||||
Expect(err).Should(Succeed())
|
|
||||||
|
|
||||||
DeferCleanup(redisServer.Close)
|
|
||||||
|
|
||||||
var rcfg config.RedisConfig
|
|
||||||
err = defaults.Set(&rcfg)
|
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
|
||||||
|
|
||||||
rcfg.Address = redisServer.Addr()
|
|
||||||
redisConfig = &rcfg
|
redisConfig = &rcfg
|
||||||
redisClient, err = New(redisConfig)
|
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
|
||||||
Expect(redisClient).ShouldNot(BeNil())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Client creation", func() {
|
Describe("Client creation", func() {
|
||||||
When("redis configuration has no address", func() {
|
When("redis configuration has no address", func() {
|
||||||
It("should return nil without error", func() {
|
It("should return nil without error", func(ctx context.Context) {
|
||||||
var rcfg config.RedisConfig
|
Expect(New(ctx, redisConfig)).Should(BeNil())
|
||||||
err = defaults.Set(&rcfg)
|
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
|
||||||
|
|
||||||
Expect(New(&rcfg)).Should(BeNil())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
When("redis configuration has invalid address", func() {
|
When("redis configuration has invalid address", func() {
|
||||||
It("should fail with error", func() {
|
BeforeEach(func() {
|
||||||
var rcfg config.RedisConfig
|
redisConfig.Address = "127.0.0.1:0"
|
||||||
err = defaults.Set(&rcfg)
|
})
|
||||||
Expect(err).Should(Succeed())
|
|
||||||
|
|
||||||
rcfg.Address = "127.0.0.1:0"
|
|
||||||
|
|
||||||
_, err = New(&rcfg)
|
|
||||||
|
|
||||||
|
It("should fail with error", func(ctx context.Context) {
|
||||||
|
_, err = New(ctx, redisConfig)
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
When("sentinel is enabled without servers", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
redisConfig.Address = "test"
|
||||||
|
redisConfig.SentinelAddresses = []string{"127.0.0.1:0"}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should fail with error", func(ctx context.Context) {
|
||||||
|
_, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
When("redis configuration has invalid password", func() {
|
When("redis configuration has invalid password", func() {
|
||||||
It("should fail with error", func() {
|
BeforeEach(func() {
|
||||||
var rcfg config.RedisConfig
|
setupRedisServer(redisConfig)
|
||||||
err = defaults.Set(&rcfg)
|
redisConfig.Password = "wrong"
|
||||||
Expect(err).Should(Succeed())
|
})
|
||||||
|
|
||||||
rcfg.Address = redisServer.Addr()
|
|
||||||
rcfg.Password = "wrong"
|
|
||||||
|
|
||||||
_, err = New(&rcfg)
|
|
||||||
|
|
||||||
|
It("should fail with error", func(ctx context.Context) {
|
||||||
|
_, err = New(ctx, redisConfig)
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Publish message", func() {
|
Describe("Publish message", func() {
|
||||||
|
var redisServer *miniredis.Miniredis
|
||||||
|
BeforeEach(func() {
|
||||||
|
redisServer = setupRedisServer(redisConfig)
|
||||||
|
})
|
||||||
|
|
||||||
When("Redis client publishes 'cache' message", func() {
|
When("Redis client publishes 'cache' message", func() {
|
||||||
It("One new entry with TTL > 0 should be persisted in the database", func() {
|
It("One new entry with TTL > 0 should be persisted in the database", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
By("Database is empty", func() {
|
By("Database is empty", func() {
|
||||||
Eventually(func() []string {
|
Eventually(func() []string {
|
||||||
return redisServer.DB(redisConfig.Database).Keys()
|
return redisServer.DB(redisConfig.Database).Keys()
|
||||||
|
@ -112,7 +112,10 @@ var _ = Describe("Redis client", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
It("One new entry with default TTL should be persisted in the database", func() {
|
It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
By("Database is empty", func() {
|
By("Database is empty", func() {
|
||||||
Eventually(func() []string {
|
Eventually(func() []string {
|
||||||
return redisServer.DB(redisConfig.Database).Keys()
|
return redisServer.DB(redisConfig.Database).Keys()
|
||||||
|
@ -138,20 +141,30 @@ var _ = Describe("Redis client", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
When("Redis client publishes 'enabled' message", func() {
|
When("Redis client publishes 'enabled' message", func() {
|
||||||
It("should propagate the message over redis", func() {
|
It("should propagate the message over redis", func(ctx context.Context) {
|
||||||
redisClient.PublishEnabled(&EnabledMessage{
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
|
redisClient.PublishEnabled(ctx, &EnabledMessage{
|
||||||
State: true,
|
State: true,
|
||||||
})
|
})
|
||||||
Eventually(func() map[string]int {
|
Eventually(func() map[string]int {
|
||||||
return redisServer.PubSubNumSub(SyncChannelName)
|
return redisServer.PubSubNumSub(SyncChannelName)
|
||||||
}).Should(HaveLen(1))
|
}).Should(HaveLen(1))
|
||||||
})
|
}, SpecTimeout(time.Second*6))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Receive message", func() {
|
Describe("Receive message", func() {
|
||||||
|
var redisServer *miniredis.Miniredis
|
||||||
|
BeforeEach(func() {
|
||||||
|
redisServer = setupRedisServer(redisConfig)
|
||||||
|
})
|
||||||
When("'enabled' message is received", func() {
|
When("'enabled' message is received", func() {
|
||||||
It("should propagate the message over the channel", func() {
|
It("should propagate the message over the channel", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
var binState []byte
|
var binState []byte
|
||||||
binState, err = json.Marshal(EnabledMessage{State: true})
|
binState, err = json.Marshal(EnabledMessage{State: true})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -179,7 +192,10 @@ var _ = Describe("Redis client", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
When("'cache' message is received", func() {
|
When("'cache' message is received", func() {
|
||||||
It("should propagate the message over the channel", func() {
|
It("should propagate the message over the channel", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123")
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -209,10 +225,13 @@ var _ = Describe("Redis client", func() {
|
||||||
Eventually(func() chan *CacheMessage {
|
Eventually(func() chan *CacheMessage {
|
||||||
return redisClient.CacheChannel
|
return redisClient.CacheChannel
|
||||||
}).Should(HaveLen(lenE + 1))
|
}).Should(HaveLen(lenE + 1))
|
||||||
})
|
}, SpecTimeout(time.Second*6))
|
||||||
})
|
})
|
||||||
When("wrong data is received", func() {
|
When("wrong data is received", func() {
|
||||||
It("should not propagate the message over the channel if data is wrong", func() {
|
It("should not propagate the message over the channel if data is wrong", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
var id []byte
|
var id []byte
|
||||||
id, err = uuid.New().MarshalBinary()
|
id, err = uuid.New().MarshalBinary()
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -239,8 +258,11 @@ var _ = Describe("Redis client", func() {
|
||||||
Eventually(func() chan *CacheMessage {
|
Eventually(func() chan *CacheMessage {
|
||||||
return redisClient.CacheChannel
|
return redisClient.CacheChannel
|
||||||
}).Should(HaveLen(lenC))
|
}).Should(HaveLen(lenC))
|
||||||
})
|
}, SpecTimeout(time.Second*6))
|
||||||
It("should not propagate the message over the channel if type is wrong", func() {
|
It("should not propagate the message over the channel if type is wrong", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
var id []byte
|
var id []byte
|
||||||
id, err = uuid.New().MarshalBinary()
|
id, err = uuid.New().MarshalBinary()
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -269,13 +291,20 @@ var _ = Describe("Redis client", func() {
|
||||||
Eventually(func() chan *CacheMessage {
|
Eventually(func() chan *CacheMessage {
|
||||||
return redisClient.CacheChannel
|
return redisClient.CacheChannel
|
||||||
}).Should(HaveLen(lenC))
|
}).Should(HaveLen(lenC))
|
||||||
})
|
}, SpecTimeout(time.Second*6))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Read the redis cache and publish it to the channel", func() {
|
Describe("Read the redis cache and publish it to the channel", func() {
|
||||||
|
var redisServer *miniredis.Miniredis
|
||||||
|
BeforeEach(func() {
|
||||||
|
redisServer = setupRedisServer(redisConfig)
|
||||||
|
})
|
||||||
When("GetRedisCache is called with valid database entries", func() {
|
When("GetRedisCache is called with valid database entries", func() {
|
||||||
It("Should read data from Redis and propagate it via cache channel", func() {
|
It("Should read data from Redis and propagate it via cache channel", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
By("Database is empty", func() {
|
By("Database is empty", func() {
|
||||||
Eventually(func() []string {
|
Eventually(func() []string {
|
||||||
return redisServer.DB(redisConfig.Database).Keys()
|
return redisServer.DB(redisConfig.Database).Keys()
|
||||||
|
@ -299,18 +328,30 @@ var _ = Describe("Redis client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() {
|
By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() {
|
||||||
redisClient.GetRedisCache()
|
redisClient.GetRedisCache(ctx)
|
||||||
|
|
||||||
Eventually(redisClient.CacheChannel).Should(HaveLen(1))
|
Eventually(redisClient.CacheChannel).Should(HaveLen(1))
|
||||||
})
|
})
|
||||||
})
|
}, SpecTimeout(time.Second*4))
|
||||||
})
|
})
|
||||||
When("GetRedisCache is called and database contains not valid entry", func() {
|
When("GetRedisCache is called and database contains not valid entry", func() {
|
||||||
It("Should do nothing (only log error)", func() {
|
It("Should do nothing (only log error)", func(ctx context.Context) {
|
||||||
|
redisClient, err = New(ctx, redisConfig)
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed())
|
Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed())
|
||||||
redisClient.GetRedisCache()
|
redisClient.GetRedisCache(ctx)
|
||||||
Consistently(redisClient.CacheChannel).Should(BeEmpty())
|
Consistently(redisClient.CacheChannel).Should(BeEmpty())
|
||||||
})
|
}, SpecTimeout(time.Second*2))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
func setupRedisServer(cfg *config.Redis) *miniredis.Miniredis {
|
||||||
|
redisServer, err := miniredis.Run()
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
DeferCleanup(redisServer.Close)
|
||||||
|
cfg.Address = redisServer.Addr()
|
||||||
|
|
||||||
|
return redisServer
|
||||||
|
}
|
||||||
|
|
|
@ -183,7 +183,7 @@ func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
|
||||||
if em.State {
|
if em.State {
|
||||||
r.internalEnableBlocking()
|
r.internalEnableBlocking()
|
||||||
} else {
|
} else {
|
||||||
err := r.internalDisableBlocking(em.Duration, em.Groups)
|
err := r.internalDisableBlocking(ctx, em.Duration, em.Groups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.log().Warn("Blocking couldn't be disabled:", err)
|
r.log().Warn("Blocking couldn't be disabled:", err)
|
||||||
}
|
}
|
||||||
|
@ -216,11 +216,11 @@ func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableBlocking enables the blocking against the blacklists
|
// EnableBlocking enables the blocking against the blacklists
|
||||||
func (r *BlockingResolver) EnableBlocking() {
|
func (r *BlockingResolver) EnableBlocking(ctx context.Context) {
|
||||||
r.internalEnableBlocking()
|
r.internalEnableBlocking()
|
||||||
|
|
||||||
if r.redisClient != nil {
|
if r.redisClient != nil {
|
||||||
r.redisClient.PublishEnabled(&redis.EnabledMessage{State: true})
|
r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{State: true})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,10 +236,10 @@ func (r *BlockingResolver) internalEnableBlocking() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisableBlocking deactivates the blocking for a particular duration (or forever if 0).
|
// DisableBlocking deactivates the blocking for a particular duration (or forever if 0).
|
||||||
func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups []string) error {
|
func (r *BlockingResolver) DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error {
|
||||||
err := r.internalDisableBlocking(duration, disableGroups)
|
err := r.internalDisableBlocking(ctx, duration, disableGroups)
|
||||||
if err == nil && r.redisClient != nil {
|
if err == nil && r.redisClient != nil {
|
||||||
r.redisClient.PublishEnabled(&redis.EnabledMessage{
|
r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{
|
||||||
State: false,
|
State: false,
|
||||||
Duration: duration,
|
Duration: duration,
|
||||||
Groups: disableGroups,
|
Groups: disableGroups,
|
||||||
|
@ -249,7 +249,9 @@ func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disableGroups []string) error {
|
func (r *BlockingResolver) internalDisableBlocking(ctx context.Context, duration time.Duration,
|
||||||
|
disableGroups []string,
|
||||||
|
) error {
|
||||||
s := r.status
|
s := r.status
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
@ -280,7 +282,7 @@ func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disab
|
||||||
log.Log().Infof("disable blocking for %s for group(s) '%s'", duration,
|
log.Log().Infof("disable blocking for %s for group(s) '%s'", duration,
|
||||||
log.EscapeInput(strings.Join(s.disabledGroups, "; ")))
|
log.EscapeInput(strings.Join(s.disabledGroups, "; ")))
|
||||||
s.enableTimer = time.AfterFunc(duration, func() {
|
s.enableTimer = time.AfterFunc(duration, func() {
|
||||||
r.EnableBlocking()
|
r.EnableBlocking(ctx)
|
||||||
log.Log().Info("blocking enabled again")
|
log.Log().Info("blocking enabled again")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -842,7 +842,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("Calling Rest API to deactivate all groups", func() {
|
By("Calling Rest API to deactivate all groups", func() {
|
||||||
err := sut.DisableBlocking(0, []string{})
|
err := sut.DisableBlocking(context.TODO(), 0, []string{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -875,7 +875,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("Calling Rest API to deactivate only defaultGroup", func() {
|
By("Calling Rest API to deactivate only defaultGroup", func() {
|
||||||
err := sut.DisableBlocking(0, []string{"defaultGroup"})
|
err := sut.DisableBlocking(context.TODO(), 0, []string{"defaultGroup"})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -935,7 +935,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
enabled <- state
|
enabled <- state
|
||||||
})
|
})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
err = sut.DisableBlocking(500*time.Millisecond, []string{})
|
err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Eventually(enabled, "1s").Should(Receive(BeFalse()))
|
Eventually(enabled, "1s").Should(Receive(BeFalse()))
|
||||||
})
|
})
|
||||||
|
@ -1025,7 +1025,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
enabled <- false
|
enabled <- false
|
||||||
})
|
})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
err = sut.DisableBlocking(500*time.Millisecond, []string{"group1"})
|
err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"group1"})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Eventually(enabled, "1s").Should(Receive(BeFalse()))
|
Eventually(enabled, "1s").Should(Receive(BeFalse()))
|
||||||
})
|
})
|
||||||
|
@ -1086,7 +1086,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
When("Disable blocking is called with wrong group name", func() {
|
When("Disable blocking is called with wrong group name", func() {
|
||||||
It("should fail", func() {
|
It("should fail", func() {
|
||||||
err := sut.DisableBlocking(500*time.Millisecond, []string{"unknownGroupName"})
|
err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"unknownGroupName"})
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -1094,7 +1094,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
When("Blocking status is called", func() {
|
When("Blocking status is called", func() {
|
||||||
It("should return correct status", func() {
|
It("should return correct status", func() {
|
||||||
By("enable blocking via API", func() {
|
By("enable blocking via API", func() {
|
||||||
sut.EnableBlocking()
|
sut.EnableBlocking(context.TODO())
|
||||||
})
|
})
|
||||||
|
|
||||||
By("Query blocking status via API should return 'enabled'", func() {
|
By("Query blocking status via API should return 'enabled'", func() {
|
||||||
|
@ -1103,7 +1103,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("disable blocking via API", func() {
|
By("disable blocking via API", func() {
|
||||||
err := sut.DisableBlocking(500*time.Millisecond, []string{})
|
err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -1149,12 +1149,12 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
var rcfg config.RedisConfig
|
var rcfg config.Redis
|
||||||
err = defaults.Set(&rcfg)
|
err = defaults.Set(&rcfg)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
rcfg.Address = redisServer.Addr()
|
rcfg.Address = redisServer.Addr()
|
||||||
redisClient, err = redis.New(&rcfg)
|
redisClient, err = redis.New(context.TODO(), &rcfg)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(redisClient).ShouldNot(BeNil())
|
Expect(redisClient).ShouldNot(BeNil())
|
||||||
|
@ -1171,7 +1171,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
When("disable", func() {
|
When("disable", func() {
|
||||||
It("should return disable", func() {
|
It("should return disable", func() {
|
||||||
sut.EnableBlocking()
|
sut.EnableBlocking(context.TODO())
|
||||||
|
|
||||||
redisMockMsg := &redis.EnabledMessage{
|
redisMockMsg := &redis.EnabledMessage{
|
||||||
State: false,
|
State: false,
|
||||||
|
@ -1185,7 +1185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
When("disable", func() {
|
When("disable", func() {
|
||||||
It("should return disable", func() {
|
It("should return disable", func() {
|
||||||
sut.EnableBlocking()
|
sut.EnableBlocking(context.TODO())
|
||||||
redisMockMsg := &redis.EnabledMessage{
|
redisMockMsg := &redis.EnabledMessage{
|
||||||
State: false,
|
State: false,
|
||||||
Groups: []string{"unknown"},
|
Groups: []string{"unknown"},
|
||||||
|
@ -1199,7 +1199,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
When("enable", func() {
|
When("enable", func() {
|
||||||
It("should return enable", func() {
|
It("should return enable", func() {
|
||||||
err = sut.DisableBlocking(time.Hour, []string{})
|
err = sut.DisableBlocking(context.TODO(), time.Hour, []string{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
redisMockMsg := &redis.EnabledMessage{
|
redisMockMsg := &redis.EnabledMessage{
|
||||||
|
|
|
@ -60,7 +60,7 @@ func newCachingResolver(ctx context.Context,
|
||||||
|
|
||||||
if c.redisClient != nil {
|
if c.redisClient != nil {
|
||||||
go c.redisSubscriber(ctx)
|
go c.redisSubscriber(ctx)
|
||||||
c.redisClient.GetRedisCache()
|
c.redisClient.GetRedisCache(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
|
@ -733,7 +733,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
var (
|
var (
|
||||||
redisServer *miniredis.Miniredis
|
redisServer *miniredis.Miniredis
|
||||||
redisClient *redis.Client
|
redisClient *redis.Client
|
||||||
redisConfig *config.RedisConfig
|
redisConfig *config.Redis
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -741,14 +741,14 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
var rcfg config.RedisConfig
|
var rcfg config.Redis
|
||||||
err = defaults.Set(&rcfg)
|
err = defaults.Set(&rcfg)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
rcfg.Address = redisServer.Addr()
|
rcfg.Address = redisServer.Addr()
|
||||||
redisConfig = &rcfg
|
redisConfig = &rcfg
|
||||||
redisClient, err = redis.New(redisConfig)
|
redisClient, err = redis.New(context.TODO(), redisConfig)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(redisClient).ShouldNot(BeNil())
|
Expect(redisClient).ShouldNot(BeNil())
|
||||||
|
|
|
@ -452,14 +452,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
|
|
||||||
Describe("Weighted random on resolver selection", func() {
|
Describe("Weighted random on resolver selection", func() {
|
||||||
When("4 upstream resolvers are defined", func() {
|
When("4 upstream resolvers are defined", func() {
|
||||||
|
var (
|
||||||
|
mockUpstream1 *MockUDPUpstreamServer
|
||||||
|
mockUpstream2 *MockUDPUpstreamServer
|
||||||
|
)
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
withError1 := config.Upstream{Host: "wrong1"}
|
withError1 := config.Upstream{Host: "wrong1"}
|
||||||
withError2 := config.Upstream{Host: "wrong2"}
|
withError2 := config.Upstream{Host: "wrong2"}
|
||||||
|
|
||||||
mockUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
mockUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||||
DeferCleanup(mockUpstream1.Close)
|
DeferCleanup(mockUpstream1.Close)
|
||||||
|
|
||||||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
mockUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||||
DeferCleanup(mockUpstream2.Close)
|
DeferCleanup(mockUpstream2.Close)
|
||||||
|
|
||||||
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
|
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
|
||||||
|
|
|
@ -31,6 +31,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancelFn context.CancelFunc
|
cancelFn context.CancelFunc
|
||||||
|
timeout = 2 * time.Second
|
||||||
|
|
||||||
|
testUpstream1 *MockUDPUpstreamServer
|
||||||
|
testUpstream2 *MockUDPUpstreamServer
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -58,6 +62,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
upstreamsCfg.StartVerify = sutVerify
|
upstreamsCfg.StartVerify = sutVerify
|
||||||
|
|
||||||
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
|
sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams)
|
||||||
|
sutConfig.Timeout = config.Duration(timeout)
|
||||||
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap)
|
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -143,10 +148,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
When("Both are responding", func() {
|
When("Both are responding", func() {
|
||||||
When("they respond in time", func() {
|
When("they respond in time", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||||
DeferCleanup(testUpstream1.Close)
|
DeferCleanup(testUpstream1.Close)
|
||||||
|
|
||||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||||
DeferCleanup(testUpstream2.Close)
|
DeferCleanup(testUpstream2.Close)
|
||||||
|
|
||||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
||||||
|
@ -165,8 +170,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
When("first upstream exceeds upstreamTimeout", func() {
|
When("first upstream exceeds upstreamTimeout", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
timeout := sut.cfg.Timeout.ToDuration()
|
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
|
||||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||||
time.Sleep(2 * timeout)
|
time.Sleep(2 * timeout)
|
||||||
|
|
||||||
|
@ -176,7 +180,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
DeferCleanup(testUpstream1.Close)
|
DeferCleanup(testUpstream1.Close)
|
||||||
|
|
||||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
||||||
DeferCleanup(testUpstream2.Close)
|
DeferCleanup(testUpstream2.Close)
|
||||||
|
|
||||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
||||||
|
@ -193,9 +197,8 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
When("all upstreams exceed upsteamTimeout", func() {
|
When("all upstreams exceed upsteamTimeout", func() {
|
||||||
BeforeEach(func() {
|
JustBeforeEach(func() {
|
||||||
timeout := sut.cfg.Timeout.ToDuration()
|
testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
|
||||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||||
time.Sleep(2 * timeout)
|
time.Sleep(2 * timeout)
|
||||||
|
|
||||||
|
@ -205,7 +208,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
DeferCleanup(testUpstream1.Close)
|
DeferCleanup(testUpstream1.Close)
|
||||||
|
|
||||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
|
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
|
||||||
time.Sleep(2 * timeout)
|
time.Sleep(2 * timeout)
|
||||||
|
|
||||||
|
@ -225,12 +228,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
When("Only second is working", func() {
|
When("Only second is working", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
testUpstream1 := config.Upstream{Host: "wrong"}
|
testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
||||||
|
|
||||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123")
|
|
||||||
DeferCleanup(testUpstream2.Close)
|
DeferCleanup(testUpstream2.Close)
|
||||||
|
|
||||||
upstreams = []config.Upstream{testUpstream1, testUpstream2.Start()}
|
upstreams = []config.Upstream{{Host: "wrong"}, testUpstream2.Start()}
|
||||||
})
|
})
|
||||||
It("Should use result from second one", func() {
|
It("Should use result from second one", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
|
|
|
@ -135,9 +135,12 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
redisClient, redisErr := redis.New(&cfg.Redis)
|
var redisClient *redis.Client
|
||||||
if redisErr != nil && cfg.Redis.Required {
|
if cfg.Redis.IsEnabled() {
|
||||||
return nil, redisErr
|
redisClient, err = redis.New(ctx, &cfg.Redis)
|
||||||
|
if err != nil && cfg.Redis.Required {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient)
|
queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient)
|
||||||
|
@ -423,6 +426,11 @@ func (s *Server) registerDNSHandlers() {
|
||||||
func (s *Server) printConfiguration() {
|
func (s *Server) printConfiguration() {
|
||||||
logger().Info("current configuration:")
|
logger().Info("current configuration:")
|
||||||
|
|
||||||
|
if s.cfg.Redis.IsEnabled() {
|
||||||
|
logger().Info("Redis:")
|
||||||
|
log.WithIndent(logger(), " ", s.cfg.Redis.LogConfig)
|
||||||
|
}
|
||||||
|
|
||||||
resolver.ForEach(s.queryResolver, func(res resolver.Resolver) {
|
resolver.ForEach(s.queryResolver, func(res resolver.Resolver) {
|
||||||
resolver.LogResolverConfig(res, logger())
|
resolver.LogResolverConfig(res, logger())
|
||||||
})
|
})
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// CtxSend sends a value to a channel while the context isn't done.
|
||||||
|
// If the message is sent, it returns true.
|
||||||
|
// If the context is done or the channel is closed, it returns false.
|
||||||
|
func CtxSend[T any](ctx context.Context, ch chan T, val T) (ok bool) {
|
||||||
|
if ctx == nil || ch == nil || ctx.Err() != nil {
|
||||||
|
ok = false
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
ok = false
|
||||||
|
case ch <- val:
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,125 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testMessage = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Context utils", func() {
|
||||||
|
Describe("CtxSend", func() {
|
||||||
|
var ch chan int
|
||||||
|
BeforeEach(func() {
|
||||||
|
ch = make(chan int, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
ch = nil
|
||||||
|
})
|
||||||
|
|
||||||
|
When("channel is not closed", func() {
|
||||||
|
It("should send value to channel", func(ctx context.Context) {
|
||||||
|
go startReader(ctx, ch)
|
||||||
|
Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("channel is closed", func() {
|
||||||
|
It("should return false", func(ctx context.Context) {
|
||||||
|
go startReader(ctx, ch)
|
||||||
|
close(ch)
|
||||||
|
Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("channel is nil", func() {
|
||||||
|
It("should return false", func(ctx context.Context) {
|
||||||
|
Expect(CtxSend(ctx, nil, testMessage)).Should(BeFalse())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("channel is full", func() {
|
||||||
|
It("should wait", func(ctx context.Context) {
|
||||||
|
ch <- testMessage
|
||||||
|
go func(ctx context.Context, ch chan int) {
|
||||||
|
timer := time.NewTimer(time.Millisecond * 200)
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
startReader(ctx, ch)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}(ctx, ch)
|
||||||
|
Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("context is done", func() {
|
||||||
|
It("should return false", func(ctx context.Context) {
|
||||||
|
go startReader(ctx, ch)
|
||||||
|
cCtx, cancel := context.WithCancel(ctx)
|
||||||
|
cancel()
|
||||||
|
<-cCtx.Done()
|
||||||
|
Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("context is terminated", func() {
|
||||||
|
It("should return false", func(ctx context.Context) {
|
||||||
|
ch <- testMessage
|
||||||
|
cCtx, cancel := context.WithCancel(ctx)
|
||||||
|
go func(ctx context.Context, cancel context.CancelFunc) {
|
||||||
|
timer := time.NewTimer(time.Millisecond * 200)
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
cancel()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}(ctx, cancel)
|
||||||
|
Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("context is nil", func() {
|
||||||
|
It("should return false", func(ctx context.Context) {
|
||||||
|
Expect(CtxSend(nil, ch, testMessage)).Should(BeFalse())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("context and channel are nil", func() {
|
||||||
|
It("should return false", func(ctx context.Context) {
|
||||||
|
Expect(CtxSend(nil, nil, testMessage)).Should(BeFalse())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
When("context is done and channel is closed", func() {
|
||||||
|
It("should return false", func(ctx context.Context) {
|
||||||
|
go startReader(ctx, ch)
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
cancel()
|
||||||
|
close(ch)
|
||||||
|
Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse())
|
||||||
|
}, SpecTimeout(time.Second))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
func startReader(ctx context.Context, ch <-chan int) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case i, ok := <-ch:
|
||||||
|
if ok {
|
||||||
|
Expect(i).Should(Equal(testMessage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue