diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 653d4034..86d175a7 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -31,19 +31,13 @@ "GitHub.vscode-github-actions" ], "settings": { - "go.lintFlags": ["--config=${containerWorkspaceFolder}/.golangci.yml"], + "go.lintFlags": [ + "--config=${containerWorkspaceFolder}/.golangci.yml", + "--fast" + ], "go.alternateTools": { "go-langserver": "gopls" }, - "[go]": { - "editor.defaultFormatter": "golang.go" - }, - "[json][jsonc][github-actions-workflow]": { - "editor.defaultFormatter": "esbenp.prettier-vscode" - }, - "[markdown]": { - "editor.defaultFormatter": "yzhang.markdown-all-in-one" - }, "markiscodecoverage.searchCriteria": "**/*.lcov", "runItOn": { "commands": [ @@ -52,6 +46,15 @@ "cmd": "${workspaceRoot}/.devcontainer/scripts/runItOnGo.sh ${fileDirname} ${workspaceRoot}" } ] + }, + "[go]": { + "editor.defaultFormatter": "golang.go" + }, + "[json][jsonc][github-actions-workflow]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[markdown]": { + "editor.defaultFormatter": "yzhang.markdown-all-in-one" } } } diff --git a/.golangci.yml b/.golangci.yml index 651bf0b7..5f7573ad 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -98,3 +98,7 @@ issues: - gochecknoinits - gochecknoglobals - gosec + - path: _test\.go + linters: + - staticcheck + text: "SA1012:" diff --git a/.vscode/settings.json b/.vscode/settings.json index a1acd12e..4f04c36b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,6 +8,7 @@ "source.organizeImports": true, "source.fixAll": true }, + "editor.rulers": [120], "go.showWelcome": false, "go.survey.prompt": false, "go.useLanguageServer": true, diff --git a/api/api_interface_impl.go b/api/api_interface_impl.go index 94043b99..109c9c5d 100644 --- a/api/api_interface_impl.go +++ b/api/api_interface_impl.go @@ -29,8 +29,8 @@ type BlockingStatus struct { // BlockingControl interface to control the blocking status type BlockingControl interface { - EnableBlocking() - DisableBlocking(duration time.Duration, disableGroups []string) error + EnableBlocking(ctx context.Context) + DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error BlockingStatus() BlockingStatus } @@ -71,7 +71,7 @@ func NewOpenAPIInterfaceImpl(control BlockingControl, } } -func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context, +func (i *OpenAPIInterfaceImpl) DisableBlocking(ctx context.Context, request DisableBlockingRequestObject, ) (DisableBlockingResponseObject, error) { var ( @@ -91,7 +91,7 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context, groups = strings.Split(*request.Params.Groups, ",") } - err = i.control.DisableBlocking(duration, groups) + err = i.control.DisableBlocking(ctx, duration, groups) if err != nil { return DisableBlocking400TextResponse(log.EscapeInput(err.Error())), nil @@ -100,9 +100,9 @@ func (i *OpenAPIInterfaceImpl) DisableBlocking(_ context.Context, return DisableBlocking200Response{}, nil } -func (i *OpenAPIInterfaceImpl) EnableBlocking(_ context.Context, _ EnableBlockingRequestObject, +func (i *OpenAPIInterfaceImpl) EnableBlocking(ctx context.Context, _ EnableBlockingRequestObject, ) (EnableBlockingResponseObject, error) { - i.control.EnableBlocking() + i.control.EnableBlocking(ctx) return EnableBlocking200Response{}, nil } diff --git a/api/api_interface_impl_test.go b/api/api_interface_impl_test.go index 3d3c8a60..9115bc69 100644 --- a/api/api_interface_impl_test.go +++ b/api/api_interface_impl_test.go @@ -37,11 +37,11 @@ func (m *ListRefreshMock) RefreshLists() error { return args.Error(0) } -func (m *BlockingControlMock) EnableBlocking() { +func (m *BlockingControlMock) EnableBlocking(_ context.Context) { _ = m.Called() } -func (m *BlockingControlMock) DisableBlocking(t time.Duration, g []string) error { +func (m *BlockingControlMock) DisableBlocking(_ context.Context, t time.Duration, g []string) error { args := m.Called(t, g) return args.Error(0) diff --git a/config/config.go b/config/config.go index 5eb7698b..20661712 100644 --- a/config/config.go +++ b/config/config.go @@ -219,7 +219,7 @@ type Config struct { Caching CachingConfig `yaml:"caching"` QueryLog QueryLogConfig `yaml:"queryLog"` Prometheus MetricsConfig `yaml:"prometheus"` - Redis RedisConfig `yaml:"redis"` + Redis Redis `yaml:"redis"` Log log.Config `yaml:"log"` Ports PortsConfig `yaml:"ports"` MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"` @@ -280,20 +280,6 @@ type ( } ) -// RedisConfig configuration for the redis connection -type RedisConfig struct { - Address string `yaml:"address"` - Username string `yaml:"username" default:""` - Password string `yaml:"password" default:""` - Database int `yaml:"database" default:"0"` - Required bool `yaml:"required" default:"false"` - ConnectionAttempts int `yaml:"connectionAttempts" default:"3"` - ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"` - SentinelUsername string `yaml:"sentinelUsername" default:""` - SentinelPassword string `yaml:"sentinelPassword" default:""` - SentinelAddresses []string `yaml:"sentinelAddresses"` -} - type ( FQDNOnly = toEnable EDE = toEnable diff --git a/config/config_test.go b/config/config_test.go index f52a92d1..8046d347 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -169,7 +169,7 @@ var _ = Describe("Config", func() { err = writeConfigDir(tmpDir) Expect(err).Should(Succeed()) - _, err := LoadConfig(tmpDir.Path, true) + c, err = LoadConfig(tmpDir.Path, true) Expect(err).Should(Succeed()) defaultTestFileConfig(c) diff --git a/config/redis.go b/config/redis.go new file mode 100644 index 00000000..79f76e9c --- /dev/null +++ b/config/redis.go @@ -0,0 +1,57 @@ +package config + +import ( + "strings" + + "github.com/sirupsen/logrus" +) + +// Redis configuration for the redis connection +type Redis struct { + Address string `yaml:"address"` + Username string `yaml:"username" default:""` + Password string `yaml:"password" default:""` + Database int `yaml:"database" default:"0"` + Required bool `yaml:"required" default:"false"` + ConnectionAttempts int `yaml:"connectionAttempts" default:"3"` + ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"` + SentinelUsername string `yaml:"sentinelUsername" default:""` + SentinelPassword string `yaml:"sentinelPassword" default:""` + SentinelAddresses []string `yaml:"sentinelAddresses"` +} + +// IsEnabled implements `config.Configurable` +func (c *Redis) IsEnabled() bool { + return c.Address != "" +} + +// LogConfig implements `config.Configurable` +func (c *Redis) LogConfig(logger *logrus.Entry) { + if len(c.SentinelAddresses) == 0 { + logger.Info("address: ", c.Address) + } + + logger.Info("username: ", c.Username) + logger.Info("password: ", obfuscatePassword(c.Password)) + logger.Info("database: ", c.Database) + logger.Info("required: ", c.Required) + logger.Info("connectionAttempts: ", c.ConnectionAttempts) + logger.Info("connectionCooldown: ", c.ConnectionCooldown) + + if len(c.SentinelAddresses) > 0 { + logger.Info("sentinel:") + logger.Info(" master: ", c.Address) + logger.Info(" username: ", c.SentinelUsername) + logger.Info(" password: ", obfuscatePassword(c.SentinelPassword)) + logger.Info(" addresses:") + + for _, addr := range c.SentinelAddresses { + logger.Info(" - ", addr) + } + } +} + +// obfuscatePassword replaces all characters of a password except the first and last with * +func obfuscatePassword(pass string) string { + return strings.Repeat("*", len(pass)) +} diff --git a/config/redis_test.go b/config/redis_test.go new file mode 100644 index 00000000..639c899e --- /dev/null +++ b/config/redis_test.go @@ -0,0 +1,104 @@ +package config + +import ( + "github.com/0xERR0R/blocky/log" + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Redis", func() { + var ( + c Redis + err error + ) + + suiteBeforeEach() + + BeforeEach(func() { + err = defaults.Set(&c) + Expect(err).Should(Succeed()) + }) + + Describe("IsEnabled", func() { + When("all fields are default", func() { + It("should be disabled", func() { + Expect(c.IsEnabled()).Should(BeFalse()) + }) + }) + + When("Address is set", func() { + BeforeEach(func() { + c.Address = "localhost:6379" + }) + + It("should be enabled", func() { + Expect(c.IsEnabled()).Should(BeTrue()) + }) + }) + }) + + Describe("LogConfig", func() { + BeforeEach(func() { + logger, hook = log.NewMockEntry() + }) + + When("all fields are default", func() { + It("should log default values", func() { + c.LogConfig(logger) + + Expect(hook.Messages).Should( + SatisfyAll(ContainElement(ContainSubstring("address: ")), + ContainElement(ContainSubstring("username: ")), + ContainElement(ContainSubstring("password: ")), + ContainElement(ContainSubstring("database: ")), + ContainElement(ContainSubstring("required: ")), + ContainElement(ContainSubstring("connectionAttempts: ")), + ContainElement(ContainSubstring("connectionCooldown: ")))) + }) + }) + + When("Address is set", func() { + BeforeEach(func() { + c.Address = "localhost:6379" + }) + + It("should log address", func() { + c.LogConfig(logger) + + Expect(hook.Messages).Should(ContainElement(ContainSubstring("address: localhost:6379"))) + }) + }) + + When("SentinelAddresses is set", func() { + BeforeEach(func() { + c.SentinelAddresses = []string{"localhost:26379", "localhost:26380"} + }) + + It("should log sentinel addresses", func() { + c.LogConfig(logger) + + Expect(hook.Messages).Should( + SatisfyAll( + ContainElement(ContainSubstring("sentinel:")), + ContainElement(ContainSubstring(" addresses:")), + ContainElement(ContainSubstring(" - localhost:26379")), + ContainElement(ContainSubstring(" - localhost:26380")))) + }) + }) + }) + + Describe("obfuscatePassword", func() { + When("password is empty", func() { + It("should return empty string", func() { + Expect(obfuscatePassword("")).Should(Equal("")) + }) + }) + + When("password is not empty", func() { + It("should return obfuscated password", func() { + Expect(obfuscatePassword("test123")).Should(Equal("*******")) + }) + }) + }) +}) diff --git a/redis/redis.go b/redis/redis.go index ae208da7..7c205f19 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -12,6 +12,7 @@ import ( "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" + "github.com/0xERR0R/blocky/util" "github.com/go-redis/redis/v8" "github.com/google/uuid" "github.com/miekg/dns" @@ -56,10 +57,9 @@ type EnabledMessage struct { // Client for redis communication type Client struct { - config *config.RedisConfig + config *config.Redis client *redis.Client l *logrus.Entry - ctx context.Context id []byte sendBuffer chan *bufferMessage CacheChannel chan *CacheMessage @@ -67,15 +67,15 @@ type Client struct { } // New creates a new redis client -func New(cfg *config.RedisConfig) (*Client, error) { +func New(ctx context.Context, cfg *config.Redis) (*Client, error) { // disable redis if no address is provided if cfg == nil || len(cfg.Address) == 0 { return nil, nil //nolint:nilnil } - var rdb *redis.Client + var baseClient *redis.Client if len(cfg.SentinelAddresses) > 0 { - rdb = redis.NewFailoverClient(&redis.FailoverOptions{ + baseClient = redis.NewFailoverClient(&redis.FailoverOptions{ MasterName: cfg.Address, SentinelUsername: cfg.Username, SentinelPassword: cfg.SentinelPassword, @@ -87,7 +87,7 @@ func New(cfg *config.RedisConfig) (*Client, error) { MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(), }) } else { - rdb = redis.NewClient(&redis.Options{ + baseClient = redis.NewClient(&redis.Options{ Addr: cfg.Address, Username: cfg.Username, Password: cfg.Password, @@ -97,7 +97,7 @@ func New(cfg *config.RedisConfig) (*Client, error) { }) } - ctx := context.Background() + rdb := baseClient.WithContext(ctx) _, err := rdb.Ping(ctx).Result() if err == nil { @@ -110,7 +110,6 @@ func New(cfg *config.RedisConfig) (*Client, error) { config: cfg, client: rdb, l: log.PrefixedLog("redis"), - ctx: ctx, id: id, sendBuffer: make(chan *bufferMessage, chanCap), CacheChannel: make(chan *CacheMessage, chanCap), @@ -118,7 +117,7 @@ func New(cfg *config.RedisConfig) (*Client, error) { } // start channel handling go routine - err = res.startup() + err = res.startup(ctx) return res, err } @@ -137,7 +136,7 @@ func (c *Client) PublishCache(key string, message *dns.Msg) { } } -func (c *Client) PublishEnabled(state *EnabledMessage) { +func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) { binState, sErr := json.Marshal(state) if sErr == nil { binMsg, mErr := json.Marshal(redisMessage{ @@ -147,22 +146,30 @@ func (c *Client) PublishEnabled(state *EnabledMessage) { }) if mErr == nil { - c.client.Publish(c.ctx, SyncChannelName, binMsg) + c.client.Publish(ctx, SyncChannelName, binMsg) } } } // GetRedisCache reads the redis cache and publish it to the channel -func (c *Client) GetRedisCache() { +func (c *Client) GetRedisCache(ctx context.Context) { c.l.Debug("GetRedisCache") go func() { - iter := c.client.Scan(c.ctx, 0, prefixKey("*"), 0).Iterator() - for iter.Next(c.ctx) { - response, err := c.getResponse(iter.Val()) + iter := c.client.Scan(ctx, 0, prefixKey("*"), 0).Iterator() + if err := iter.Err(); err != nil { + c.l.Error("GetRedisCache ", err) + + return + } + + for iter.Next(ctx) { + response, err := c.getResponse(ctx, iter.Val()) if err == nil { if response != nil { - c.CacheChannel <- response + if !util.CtxSend(ctx, c.CacheChannel, response) { + return + } } } else { c.l.Error("GetRedisCache ", err) @@ -172,10 +179,10 @@ func (c *Client) GetRedisCache() { } // startup starts a new goroutine for subscription and translation -func (c *Client) startup() error { - ps := c.client.Subscribe(c.ctx, SyncChannelName) +func (c *Client) startup(ctx context.Context) error { + ps := c.client.Subscribe(ctx, SyncChannelName) - _, err := ps.Receive(c.ctx) + _, err := ps.Receive(ctx) if err == nil { go func() { for { @@ -186,11 +193,16 @@ func (c *Client) startup() error { if msg != nil && len(msg.Payload) > 0 { // message is not empty - c.processReceivedMessage(msg) + c.processReceivedMessage(ctx, msg) } // publish message from buffer case s := <-c.sendBuffer: - c.publishMessageFromBuffer(s) + c.publishMessageFromBuffer(ctx, s) + // context is done + case <-ctx.Done(): + c.client.Close() + + return } } }() @@ -199,7 +211,7 @@ func (c *Client) startup() error { return err } -func (c *Client) publishMessageFromBuffer(s *bufferMessage) { +func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) { origRes := s.Message origRes.Compress = true binRes, pErr := origRes.Pack() @@ -213,61 +225,61 @@ func (c *Client) publishMessageFromBuffer(s *bufferMessage) { }) if mErr == nil { - c.client.Publish(c.ctx, SyncChannelName, binMsg) + c.client.Publish(ctx, SyncChannelName, binMsg) } - c.client.Set(c.ctx, + c.client.Set(ctx, prefixKey(s.Key), binRes, c.getTTL(origRes)) } } -func (c *Client) processReceivedMessage(msg *redis.Message) { +func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) { var rm redisMessage - err := json.Unmarshal([]byte(msg.Payload), &rm) - if err == nil { - // message was sent from a different blocky instance - if !bytes.Equal(rm.Client, c.id) { - switch rm.Type { - case messageTypeCache: - var cm *CacheMessage + if err := json.Unmarshal([]byte(msg.Payload), &rm); err != nil { + c.l.Error("Processing error: ", err) - cm, err = convertMessage(&rm, 0) - if err == nil { - c.CacheChannel <- cm - } - case messageTypeEnable: - err = c.processEnabledMessage(&rm) - default: - c.l.Warn("Unknown message type: ", rm.Type) + return + } + + // message was sent from a different blocky instance + if !bytes.Equal(rm.Client, c.id) { + switch rm.Type { + case messageTypeCache: + var cm *CacheMessage + + cm, err := convertMessage(&rm, 0) + if err != nil { + c.l.Error("Processing CacheMessage error: ", err) + + return } + + util.CtxSend(ctx, c.CacheChannel, cm) + case messageTypeEnable: + var msg EnabledMessage + + if err := json.Unmarshal(rm.Message, &msg); err != nil { + c.l.Error("Processing EnabledMessage error: ", err) + + return + } + + util.CtxSend(ctx, c.EnabledChannel, &msg) + default: + c.l.Warn("Unknown message type: ", rm.Type) } } - - if err != nil { - c.l.Error("Processing error: ", err) - } -} - -func (c *Client) processEnabledMessage(redisMsg *redisMessage) error { - var msg EnabledMessage - - err := json.Unmarshal(redisMsg.Message, &msg) - if err == nil { - c.EnabledChannel <- &msg - } - - return err } // getResponse returns model.Response for a key -func (c *Client) getResponse(key string) (*CacheMessage, error) { - resp, err := c.client.Get(c.ctx, key).Result() +func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) { + resp, err := c.client.Get(ctx, key).Result() if err == nil { var ttl time.Duration - ttl, err = c.client.TTL(c.ctx, key).Result() + ttl, err = c.client.TTL(ctx, key).Result() if err == nil { var result *CacheMessage diff --git a/redis/redis_test.go b/redis/redis_test.go index 04efb6b7..5c7f530a 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -1,6 +1,7 @@ package redis import ( + "context" "encoding/json" "time" @@ -18,76 +19,75 @@ const ( exampleComKey = CacheStorePrefix + "example.com" ) -var ( - redisServer *miniredis.Miniredis - redisClient *Client - redisConfig *config.RedisConfig - err error -) - var _ = Describe("Redis client", func() { + var ( + redisConfig *config.Redis + + redisClient *Client + + err error + ) + BeforeEach(func() { - redisServer, err = miniredis.Run() - - Expect(err).Should(Succeed()) - - DeferCleanup(redisServer.Close) - - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - - Expect(err).Should(Succeed()) - - rcfg.Address = redisServer.Addr() + var rcfg config.Redis + Expect(defaults.Set(&rcfg)).Should(Succeed()) redisConfig = &rcfg - redisClient, err = New(redisConfig) - - Expect(err).Should(Succeed()) - Expect(redisClient).ShouldNot(BeNil()) }) + Describe("Client creation", func() { When("redis configuration has no address", func() { - It("should return nil without error", func() { - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - - Expect(err).Should(Succeed()) - - Expect(New(&rcfg)).Should(BeNil()) + It("should return nil without error", func(ctx context.Context) { + Expect(New(ctx, redisConfig)).Should(BeNil()) }) }) + When("redis configuration has invalid address", func() { - It("should fail with error", func() { - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - Expect(err).Should(Succeed()) - - rcfg.Address = "127.0.0.1:0" - - _, err = New(&rcfg) + BeforeEach(func() { + redisConfig.Address = "127.0.0.1:0" + }) + It("should fail with error", func(ctx context.Context) { + _, err = New(ctx, redisConfig) Expect(err).Should(HaveOccurred()) }) }) + + When("sentinel is enabled without servers", func() { + BeforeEach(func() { + redisConfig.Address = "test" + redisConfig.SentinelAddresses = []string{"127.0.0.1:0"} + }) + + It("should fail with error", func(ctx context.Context) { + _, err = New(ctx, redisConfig) + Expect(err).Should(HaveOccurred()) + }) + }) + When("redis configuration has invalid password", func() { - It("should fail with error", func() { - var rcfg config.RedisConfig - err = defaults.Set(&rcfg) - Expect(err).Should(Succeed()) - - rcfg.Address = redisServer.Addr() - rcfg.Password = "wrong" - - _, err = New(&rcfg) + BeforeEach(func() { + setupRedisServer(redisConfig) + redisConfig.Password = "wrong" + }) + It("should fail with error", func(ctx context.Context) { + _, err = New(ctx, redisConfig) Expect(err).Should(HaveOccurred()) }) }) }) Describe("Publish message", func() { + var redisServer *miniredis.Miniredis + BeforeEach(func() { + redisServer = setupRedisServer(redisConfig) + }) + When("Redis client publishes 'cache' message", func() { - It("One new entry with TTL > 0 should be persisted in the database", func() { + It("One new entry with TTL > 0 should be persisted in the database", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + By("Database is empty", func() { Eventually(func() []string { return redisServer.DB(redisConfig.Database).Keys() @@ -112,7 +112,10 @@ var _ = Describe("Redis client", func() { }) }) - It("One new entry with default TTL should be persisted in the database", func() { + It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + By("Database is empty", func() { Eventually(func() []string { return redisServer.DB(redisConfig.Database).Keys() @@ -138,20 +141,30 @@ var _ = Describe("Redis client", func() { }) }) When("Redis client publishes 'enabled' message", func() { - It("should propagate the message over redis", func() { - redisClient.PublishEnabled(&EnabledMessage{ + It("should propagate the message over redis", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + + redisClient.PublishEnabled(ctx, &EnabledMessage{ State: true, }) Eventually(func() map[string]int { return redisServer.PubSubNumSub(SyncChannelName) }).Should(HaveLen(1)) - }) + }, SpecTimeout(time.Second*6)) }) }) Describe("Receive message", func() { + var redisServer *miniredis.Miniredis + BeforeEach(func() { + redisServer = setupRedisServer(redisConfig) + }) When("'enabled' message is received", func() { - It("should propagate the message over the channel", func() { + It("should propagate the message over the channel", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + var binState []byte binState, err = json.Marshal(EnabledMessage{State: true}) Expect(err).Should(Succeed()) @@ -179,7 +192,10 @@ var _ = Describe("Redis client", func() { }) }) When("'cache' message is received", func() { - It("should propagate the message over the channel", func() { + It("should propagate the message over the channel", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") Expect(err).Should(Succeed()) @@ -209,10 +225,13 @@ var _ = Describe("Redis client", func() { Eventually(func() chan *CacheMessage { return redisClient.CacheChannel }).Should(HaveLen(lenE + 1)) - }) + }, SpecTimeout(time.Second*6)) }) When("wrong data is received", func() { - It("should not propagate the message over the channel if data is wrong", func() { + It("should not propagate the message over the channel if data is wrong", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + var id []byte id, err = uuid.New().MarshalBinary() Expect(err).Should(Succeed()) @@ -239,8 +258,11 @@ var _ = Describe("Redis client", func() { Eventually(func() chan *CacheMessage { return redisClient.CacheChannel }).Should(HaveLen(lenC)) - }) - It("should not propagate the message over the channel if type is wrong", func() { + }, SpecTimeout(time.Second*6)) + It("should not propagate the message over the channel if type is wrong", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + var id []byte id, err = uuid.New().MarshalBinary() Expect(err).Should(Succeed()) @@ -269,13 +291,20 @@ var _ = Describe("Redis client", func() { Eventually(func() chan *CacheMessage { return redisClient.CacheChannel }).Should(HaveLen(lenC)) - }) + }, SpecTimeout(time.Second*6)) }) }) Describe("Read the redis cache and publish it to the channel", func() { + var redisServer *miniredis.Miniredis + BeforeEach(func() { + redisServer = setupRedisServer(redisConfig) + }) When("GetRedisCache is called with valid database entries", func() { - It("Should read data from Redis and propagate it via cache channel", func() { + It("Should read data from Redis and propagate it via cache channel", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + By("Database is empty", func() { Eventually(func() []string { return redisServer.DB(redisConfig.Database).Keys() @@ -299,18 +328,30 @@ var _ = Describe("Redis client", func() { }) By("call GetRedisCache - It should read one entry from redis and propagate it via channel", func() { - redisClient.GetRedisCache() + redisClient.GetRedisCache(ctx) Eventually(redisClient.CacheChannel).Should(HaveLen(1)) }) - }) + }, SpecTimeout(time.Second*4)) }) When("GetRedisCache is called and database contains not valid entry", func() { - It("Should do nothing (only log error)", func() { + It("Should do nothing (only log error)", func(ctx context.Context) { + redisClient, err = New(ctx, redisConfig) + Expect(err).Should(Succeed()) + Expect(redisServer.DB(redisConfig.Database).Set(CacheStorePrefix+"test", "test")).Should(Succeed()) - redisClient.GetRedisCache() + redisClient.GetRedisCache(ctx) Consistently(redisClient.CacheChannel).Should(BeEmpty()) - }) + }, SpecTimeout(time.Second*2)) }) }) }) + +func setupRedisServer(cfg *config.Redis) *miniredis.Miniredis { + redisServer, err := miniredis.Run() + Expect(err).Should(Succeed()) + DeferCleanup(redisServer.Close) + cfg.Address = redisServer.Addr() + + return redisServer +} diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 290a0cef..662f3514 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -183,7 +183,7 @@ func (r *BlockingResolver) redisSubscriber(ctx context.Context) { if em.State { r.internalEnableBlocking() } else { - err := r.internalDisableBlocking(em.Duration, em.Groups) + err := r.internalDisableBlocking(ctx, em.Duration, em.Groups) if err != nil { r.log().Warn("Blocking couldn't be disabled:", err) } @@ -216,11 +216,11 @@ func (r *BlockingResolver) retrieveAllBlockingGroups() []string { } // EnableBlocking enables the blocking against the blacklists -func (r *BlockingResolver) EnableBlocking() { +func (r *BlockingResolver) EnableBlocking(ctx context.Context) { r.internalEnableBlocking() if r.redisClient != nil { - r.redisClient.PublishEnabled(&redis.EnabledMessage{State: true}) + r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{State: true}) } } @@ -236,10 +236,10 @@ func (r *BlockingResolver) internalEnableBlocking() { } // DisableBlocking deactivates the blocking for a particular duration (or forever if 0). -func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups []string) error { - err := r.internalDisableBlocking(duration, disableGroups) +func (r *BlockingResolver) DisableBlocking(ctx context.Context, duration time.Duration, disableGroups []string) error { + err := r.internalDisableBlocking(ctx, duration, disableGroups) if err == nil && r.redisClient != nil { - r.redisClient.PublishEnabled(&redis.EnabledMessage{ + r.redisClient.PublishEnabled(ctx, &redis.EnabledMessage{ State: false, Duration: duration, Groups: disableGroups, @@ -249,7 +249,9 @@ func (r *BlockingResolver) DisableBlocking(duration time.Duration, disableGroups return err } -func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disableGroups []string) error { +func (r *BlockingResolver) internalDisableBlocking(ctx context.Context, duration time.Duration, + disableGroups []string, +) error { s := r.status s.lock.Lock() defer s.lock.Unlock() @@ -280,7 +282,7 @@ func (r *BlockingResolver) internalDisableBlocking(duration time.Duration, disab log.Log().Infof("disable blocking for %s for group(s) '%s'", duration, log.EscapeInput(strings.Join(s.disabledGroups, "; "))) s.enableTimer = time.AfterFunc(duration, func() { - r.EnableBlocking() + r.EnableBlocking(ctx) log.Log().Info("blocking enabled again") }) } diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index 6d40f7b1..0b7e772b 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -842,7 +842,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("Calling Rest API to deactivate all groups", func() { - err := sut.DisableBlocking(0, []string{}) + err := sut.DisableBlocking(context.TODO(), 0, []string{}) Expect(err).Should(Succeed()) }) @@ -875,7 +875,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("Calling Rest API to deactivate only defaultGroup", func() { - err := sut.DisableBlocking(0, []string{"defaultGroup"}) + err := sut.DisableBlocking(context.TODO(), 0, []string{"defaultGroup"}) Expect(err).Should(Succeed()) }) @@ -935,7 +935,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { enabled <- state }) Expect(err).Should(Succeed()) - err = sut.DisableBlocking(500*time.Millisecond, []string{}) + err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{}) Expect(err).Should(Succeed()) Eventually(enabled, "1s").Should(Receive(BeFalse())) }) @@ -1025,7 +1025,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { enabled <- false }) Expect(err).Should(Succeed()) - err = sut.DisableBlocking(500*time.Millisecond, []string{"group1"}) + err = sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"group1"}) Expect(err).Should(Succeed()) Eventually(enabled, "1s").Should(Receive(BeFalse())) }) @@ -1086,7 +1086,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Disable blocking is called with wrong group name", func() { It("should fail", func() { - err := sut.DisableBlocking(500*time.Millisecond, []string{"unknownGroupName"}) + err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{"unknownGroupName"}) Expect(err).Should(HaveOccurred()) }) }) @@ -1094,7 +1094,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Blocking status is called", func() { It("should return correct status", func() { By("enable blocking via API", func() { - sut.EnableBlocking() + sut.EnableBlocking(context.TODO()) }) By("Query blocking status via API should return 'enabled'", func() { @@ -1103,7 +1103,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("disable blocking via API", func() { - err := sut.DisableBlocking(500*time.Millisecond, []string{}) + err := sut.DisableBlocking(context.TODO(), 500*time.Millisecond, []string{}) Expect(err).Should(Succeed()) }) @@ -1149,12 +1149,12 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { Expect(err).Should(Succeed()) - var rcfg config.RedisConfig + var rcfg config.Redis err = defaults.Set(&rcfg) Expect(err).Should(Succeed()) rcfg.Address = redisServer.Addr() - redisClient, err = redis.New(&rcfg) + redisClient, err = redis.New(context.TODO(), &rcfg) Expect(err).Should(Succeed()) Expect(redisClient).ShouldNot(BeNil()) @@ -1171,7 +1171,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("disable", func() { It("should return disable", func() { - sut.EnableBlocking() + sut.EnableBlocking(context.TODO()) redisMockMsg := &redis.EnabledMessage{ State: false, @@ -1185,7 +1185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("disable", func() { It("should return disable", func() { - sut.EnableBlocking() + sut.EnableBlocking(context.TODO()) redisMockMsg := &redis.EnabledMessage{ State: false, Groups: []string{"unknown"}, @@ -1199,7 +1199,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("enable", func() { It("should return enable", func() { - err = sut.DisableBlocking(time.Hour, []string{}) + err = sut.DisableBlocking(context.TODO(), time.Hour, []string{}) Expect(err).Should(Succeed()) redisMockMsg := &redis.EnabledMessage{ diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 7851012a..d10ae278 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -60,7 +60,7 @@ func newCachingResolver(ctx context.Context, if c.redisClient != nil { go c.redisSubscriber(ctx) - c.redisClient.GetRedisCache() + c.redisClient.GetRedisCache(ctx) } return c diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index b7c53f21..3f88ef4b 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -733,7 +733,7 @@ var _ = Describe("CachingResolver", func() { var ( redisServer *miniredis.Miniredis redisClient *redis.Client - redisConfig *config.RedisConfig + redisConfig *config.Redis err error ) BeforeEach(func() { @@ -741,14 +741,14 @@ var _ = Describe("CachingResolver", func() { Expect(err).Should(Succeed()) - var rcfg config.RedisConfig + var rcfg config.Redis err = defaults.Set(&rcfg) Expect(err).Should(Succeed()) rcfg.Address = redisServer.Addr() redisConfig = &rcfg - redisClient, err = redis.New(redisConfig) + redisClient, err = redis.New(context.TODO(), redisConfig) Expect(err).Should(Succeed()) Expect(redisClient).ShouldNot(BeNil()) diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index c9c71c11..5b74fb44 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -452,14 +452,18 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { Describe("Weighted random on resolver selection", func() { When("4 upstream resolvers are defined", func() { + var ( + mockUpstream1 *MockUDPUpstreamServer + mockUpstream2 *MockUDPUpstreamServer + ) BeforeEach(func() { withError1 := config.Upstream{Host: "wrong1"} withError2 := config.Upstream{Host: "wrong2"} - mockUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + mockUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream1.Close) - mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + mockUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2} diff --git a/resolver/strict_resolver_test.go b/resolver/strict_resolver_test.go index 950c9e62..b974044f 100644 --- a/resolver/strict_resolver_test.go +++ b/resolver/strict_resolver_test.go @@ -31,6 +31,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { ctx context.Context cancelFn context.CancelFunc + timeout = 2 * time.Second + + testUpstream1 *MockUDPUpstreamServer + testUpstream2 *MockUDPUpstreamServer ) Describe("Type", func() { @@ -58,6 +62,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { upstreamsCfg.StartVerify = sutVerify sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams) + sutConfig.Timeout = config.Duration(timeout) sut, err = NewStrictResolver(ctx, sutConfig, bootstrap) }) @@ -143,10 +148,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { When("Both are responding", func() { When("they respond in time", func() { BeforeEach(func() { - testUpstream1 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") + testUpstream1 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(testUpstream1.Close) - testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") + testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") DeferCleanup(testUpstream2.Close) upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} @@ -165,8 +170,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) When("first upstream exceeds upstreamTimeout", func() { BeforeEach(func() { - timeout := sut.cfg.Timeout.ToDuration() - testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") time.Sleep(2 * timeout) @@ -176,7 +180,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) DeferCleanup(testUpstream1.Close) - testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2") + testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2") DeferCleanup(testUpstream2.Close) upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} @@ -193,9 +197,8 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) }) When("all upstreams exceed upsteamTimeout", func() { - BeforeEach(func() { - timeout := sut.cfg.Timeout.ToDuration() - testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + JustBeforeEach(func() { + testUpstream1 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") time.Sleep(2 * timeout) @@ -205,7 +208,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) DeferCleanup(testUpstream1.Close) - testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { + testUpstream2 = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2") time.Sleep(2 * timeout) @@ -225,12 +228,10 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) When("Only second is working", func() { BeforeEach(func() { - testUpstream1 := config.Upstream{Host: "wrong"} - - testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") + testUpstream2 = NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.123") DeferCleanup(testUpstream2.Close) - upstreams = []config.Upstream{testUpstream1, testUpstream2.Start()} + upstreams = []config.Upstream{{Host: "wrong"}, testUpstream2.Start()} }) It("Should use result from second one", func() { request := newRequest("example.com.", A) diff --git a/server/server.go b/server/server.go index 8516ff78..c6241728 100644 --- a/server/server.go +++ b/server/server.go @@ -135,9 +135,12 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, err } - redisClient, redisErr := redis.New(&cfg.Redis) - if redisErr != nil && cfg.Redis.Required { - return nil, redisErr + var redisClient *redis.Client + if cfg.Redis.IsEnabled() { + redisClient, err = redis.New(ctx, &cfg.Redis) + if err != nil && cfg.Redis.Required { + return nil, err + } } queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient) @@ -423,6 +426,11 @@ func (s *Server) registerDNSHandlers() { func (s *Server) printConfiguration() { logger().Info("current configuration:") + if s.cfg.Redis.IsEnabled() { + logger().Info("Redis:") + log.WithIndent(logger(), " ", s.cfg.Redis.LogConfig) + } + resolver.ForEach(s.queryResolver, func(res resolver.Resolver) { resolver.LogResolverConfig(res, logger()) }) diff --git a/util/context.go b/util/context.go new file mode 100644 index 00000000..4221ecae --- /dev/null +++ b/util/context.go @@ -0,0 +1,29 @@ +package util + +import "context" + +// CtxSend sends a value to a channel while the context isn't done. +// If the message is sent, it returns true. +// If the context is done or the channel is closed, it returns false. +func CtxSend[T any](ctx context.Context, ch chan T, val T) (ok bool) { + if ctx == nil || ch == nil || ctx.Err() != nil { + ok = false + + return + } + + defer func() { + if err := recover(); err != nil { + ok = false + } + }() + + select { + case <-ctx.Done(): + ok = false + case ch <- val: + ok = true + } + + return +} diff --git a/util/context_test.go b/util/context_test.go new file mode 100644 index 00000000..ccd3b0db --- /dev/null +++ b/util/context_test.go @@ -0,0 +1,125 @@ +package util + +import ( + "context" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +const ( + testMessage = 1 +) + +var _ = Describe("Context utils", func() { + Describe("CtxSend", func() { + var ch chan int + BeforeEach(func() { + ch = make(chan int, 1) + }) + + AfterEach(func() { + ch = nil + }) + + When("channel is not closed", func() { + It("should send value to channel", func(ctx context.Context) { + go startReader(ctx, ch) + Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue()) + }, SpecTimeout(time.Second)) + }) + + When("channel is closed", func() { + It("should return false", func(ctx context.Context) { + go startReader(ctx, ch) + close(ch) + Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse()) + }, SpecTimeout(time.Second)) + }) + + When("channel is nil", func() { + It("should return false", func(ctx context.Context) { + Expect(CtxSend(ctx, nil, testMessage)).Should(BeFalse()) + }, SpecTimeout(time.Second)) + }) + + When("channel is full", func() { + It("should wait", func(ctx context.Context) { + ch <- testMessage + go func(ctx context.Context, ch chan int) { + timer := time.NewTimer(time.Millisecond * 200) + select { + case <-timer.C: + startReader(ctx, ch) + case <-ctx.Done(): + return + } + }(ctx, ch) + Expect(CtxSend(ctx, ch, testMessage)).Should(BeTrue()) + }, SpecTimeout(time.Second)) + }) + + When("context is done", func() { + It("should return false", func(ctx context.Context) { + go startReader(ctx, ch) + cCtx, cancel := context.WithCancel(ctx) + cancel() + <-cCtx.Done() + Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse()) + }, SpecTimeout(time.Second)) + }) + + When("context is terminated", func() { + It("should return false", func(ctx context.Context) { + ch <- testMessage + cCtx, cancel := context.WithCancel(ctx) + go func(ctx context.Context, cancel context.CancelFunc) { + timer := time.NewTimer(time.Millisecond * 200) + select { + case <-timer.C: + cancel() + case <-ctx.Done(): + return + } + }(ctx, cancel) + Expect(CtxSend(cCtx, ch, testMessage)).Should(BeFalse()) + }, SpecTimeout(time.Second)) + }) + + When("context is nil", func() { + It("should return false", func(ctx context.Context) { + Expect(CtxSend(nil, ch, testMessage)).Should(BeFalse()) + }, SpecTimeout(time.Second)) + }) + + When("context and channel are nil", func() { + It("should return false", func(ctx context.Context) { + Expect(CtxSend(nil, nil, testMessage)).Should(BeFalse()) + }, SpecTimeout(time.Second)) + }) + + When("context is done and channel is closed", func() { + It("should return false", func(ctx context.Context) { + go startReader(ctx, ch) + ctx, cancel := context.WithCancel(ctx) + cancel() + close(ch) + Expect(CtxSend(ctx, ch, testMessage)).Should(BeFalse()) + }, SpecTimeout(time.Second)) + }) + }) +}) + +func startReader(ctx context.Context, ch <-chan int) { + for { + select { + case <-ctx.Done(): + return + case i, ok := <-ch: + if ok { + Expect(i).Should(Equal(testMessage)) + } + } + } +}