diff --git a/cache/expirationcache/expiration_cache.go b/cache/expirationcache/expiration_cache.go index 01b13449..3dec2514 100644 --- a/cache/expirationcache/expiration_cache.go +++ b/cache/expirationcache/expiration_cache.go @@ -1,6 +1,7 @@ package expirationcache import ( + "context" "time" lru "github.com/hashicorp/golang-lru" @@ -47,11 +48,11 @@ type OnCacheMissCallback func(key string) // OnAfterPutCallback will be called after put, receives new element count as parameter type OnAfterPutCallback func(newSize int) -func NewCache[T any](options Options) *ExpiringLRUCache[T] { - return NewCacheWithOnExpired[T](options, nil) +func NewCache[T any](ctx context.Context, options Options) *ExpiringLRUCache[T] { + return NewCacheWithOnExpired[T](ctx, options, nil) } -func NewCacheWithOnExpired[T any](options Options, +func NewCacheWithOnExpired[T any](ctx context.Context, options Options, onExpirationFn OnExpirationCallback[T], ) *ExpiringLRUCache[T] { l, _ := lru.New(defaultSize) @@ -90,18 +91,22 @@ func NewCacheWithOnExpired[T any](options Options, c.preExpirationFn = onExpirationFn } - go periodicCleanup(c) + go periodicCleanup(ctx, c) return c } -func periodicCleanup[T any](c *ExpiringLRUCache[T]) { +func periodicCleanup[T any](ctx context.Context, c *ExpiringLRUCache[T]) { ticker := time.NewTicker(c.cleanUpInterval) defer ticker.Stop() for { - <-ticker.C - c.cleanUp() + select { + case <-ticker.C: + c.cleanUp() + case <-ctx.Done(): + return + } } } diff --git a/cache/expirationcache/expiration_cache_test.go b/cache/expirationcache/expiration_cache_test.go index dad89e62..9364350b 100644 --- a/cache/expirationcache/expiration_cache_test.go +++ b/cache/expirationcache/expiration_cache_test.go @@ -1,6 +1,7 @@ package expirationcache import ( + "context" "time" . "github.com/onsi/ginkgo/v2" @@ -8,14 +9,22 @@ import ( ) var _ = Describe("Expiration cache", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + }) Describe("Basic operations", func() { When("string cache was created", func() { It("Initial cache should be empty", func() { - cache := NewCache[string](Options{}) + cache := NewCache[string](ctx, Options{}) Expect(cache.TotalCount()).Should(Equal(0)) }) It("Initial cache should not contain any elements", func() { - cache := NewCache[string](Options{}) + cache := NewCache[string](ctx, Options{}) val, expiration := cache.Get("key1") Expect(val).Should(BeNil()) Expect(expiration).Should(Equal(time.Duration(0))) @@ -23,7 +32,7 @@ var _ = Describe("Expiration cache", func() { }) When("Put new value with positive TTL", func() { It("Should return the value before element expires", func() { - cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond}) + cache := NewCache[string](ctx, Options{CleanupInterval: 100 * time.Millisecond}) v := "v1" cache.Put("key1", &v, 50*time.Millisecond) val, expiration := cache.Get("key1") @@ -33,7 +42,7 @@ var _ = Describe("Expiration cache", func() { Expect(cache.TotalCount()).Should(Equal(1)) }) It("Should return nil after expiration", func() { - cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond}) + cache := NewCache[string](ctx, Options{CleanupInterval: 100 * time.Millisecond}) v := "v1" cache.Put("key1", &v, 50*time.Millisecond) @@ -47,12 +56,12 @@ var _ = Describe("Expiration cache", func() { // wait for cleanup run Eventually(func() int { return cache.lru.Len() - }, "100ms").Should(Equal(0)) + }).Should(Equal(0)) }) }) When("Put new value without expiration", func() { It("Should not cache the value", func() { - cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond}) + cache := NewCache[string](ctx, Options{CleanupInterval: 50 * time.Millisecond}) v := "x" cache.Put("key1", &v, 0) val, expiration := cache.Get("key1") @@ -63,7 +72,7 @@ var _ = Describe("Expiration cache", func() { }) When("Put updated value", func() { It("Should return updated value", func() { - cache := NewCache[string](Options{}) + cache := NewCache[string](ctx, Options{}) v1 := "v1" v2 := "v2" cache.Put("key1", &v1, 50*time.Millisecond) @@ -79,7 +88,7 @@ var _ = Describe("Expiration cache", func() { }) When("Purging after usage", func() { It("Should be empty after purge", func() { - cache := NewCache[string](Options{}) + cache := NewCache[string](ctx, Options{}) v1 := "y" cache.Put("key1", &v1, time.Second) @@ -97,7 +106,7 @@ var _ = Describe("Expiration cache", func() { onCacheHitChannel := make(chan string, 10) onCacheMissChannel := make(chan string, 10) onAfterPutChannel := make(chan int, 10) - cache := NewCache[string](Options{ + cache := NewCache[string](ctx, Options{ OnCacheHitFn: func(key string) { onCacheHitChannel <- key }, @@ -146,7 +155,7 @@ var _ = Describe("Expiration cache", func() { return &v2, time.Second } - cache := NewCacheWithOnExpired[string](Options{}, fn) + cache := NewCacheWithOnExpired[string](ctx, Options{}, fn) v1 := "v1" cache.Put("key1", &v1, 50*time.Millisecond) @@ -165,7 +174,7 @@ var _ = Describe("Expiration cache", func() { return &v2, time.Second } - cache := NewCacheWithOnExpired[string](Options{}, fn) + cache := NewCacheWithOnExpired[string](ctx, Options{}, fn) v1 := "somval" cache.Put("key1", &v1, time.Millisecond) @@ -186,7 +195,7 @@ var _ = Describe("Expiration cache", func() { fn := func(key string) (val *string, ttl time.Duration) { return nil, 0 } - cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn) + cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn) v1 := "z" cache.Put("key1", &v1, 50*time.Millisecond) @@ -199,7 +208,7 @@ var _ = Describe("Expiration cache", func() { Describe("LRU behaviour", func() { When("Defined max size is reached", func() { It("should remove old elements", func() { - cache := NewCache[string](Options{MaxSize: 3}) + cache := NewCache[string](ctx, Options{MaxSize: 3}) v1 := "val1" v2 := "val2" diff --git a/cache/expirationcache/prefetching_cache.go b/cache/expirationcache/prefetching_cache.go index 1dec5d09..e30021d7 100644 --- a/cache/expirationcache/prefetching_cache.go +++ b/cache/expirationcache/prefetching_cache.go @@ -1,6 +1,7 @@ package expirationcache import ( + "context" "sync/atomic" "time" ) @@ -39,9 +40,9 @@ type PrefetchingOptions[T any] struct { type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]]) -func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] { +func NewPrefetchingCache[T any](ctx context.Context, options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] { pc := &PrefetchingExpiringLRUCache[T]{ - prefetchingNameCache: NewCache[atomic.Uint32](Options{ + prefetchingNameCache: NewCache[atomic.Uint32](ctx, Options{ CleanupInterval: time.Minute, MaxSize: uint(options.PrefetchMaxItemsCount), OnAfterPutFn: options.OnPrefetchAfterPut, @@ -53,7 +54,7 @@ func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpir onPrefetchCacheHit: options.OnPrefetchCacheHit, } - pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired) + pc.cache = NewCacheWithOnExpired[cacheValue[T]](ctx, options.Options, pc.onExpired) return pc } diff --git a/cache/expirationcache/prefetching_cache_test.go b/cache/expirationcache/prefetching_cache_test.go index 7a1ef80b..91b796a1 100644 --- a/cache/expirationcache/prefetching_cache_test.go +++ b/cache/expirationcache/prefetching_cache_test.go @@ -1,6 +1,7 @@ package expirationcache import ( + "context" "time" . "github.com/onsi/ginkgo/v2" @@ -8,21 +9,29 @@ import ( ) var _ = Describe("Prefetching expiration cache", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + }) Describe("Basic operations", func() { When("string cache was created", func() { It("Initial cache should be empty", func() { - cache := NewPrefetchingCache[string](PrefetchingOptions[string]{}) + cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{}) Expect(cache.TotalCount()).Should(Equal(0)) }) It("Initial cache should not contain any elements", func() { - cache := NewPrefetchingCache[string](PrefetchingOptions[string]{}) + cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{}) val, expiration := cache.Get("key1") Expect(val).Should(BeNil()) Expect(expiration).Should(Equal(time.Duration(0))) }) It("Should work as cache (basic operations)", func() { - cache := NewPrefetchingCache[string](PrefetchingOptions[string]{}) + cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{}) v := "v1" cache.Put("key1", &v, 50*time.Millisecond) @@ -39,7 +48,7 @@ var _ = Describe("Prefetching expiration cache", func() { }) Context("Prefetching", func() { It("Should prefetch element", func() { - cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{ Options: Options{ CleanupInterval: 100 * time.Millisecond, }, @@ -71,7 +80,7 @@ var _ = Describe("Prefetching expiration cache", func() { }) }) It("Should not prefetch element", func() { - cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{ Options: Options{ CleanupInterval: 100 * time.Millisecond, }, @@ -100,7 +109,7 @@ var _ = Describe("Prefetching expiration cache", func() { }) }) It("With default config (threshold = 0) should always prefetch", func() { - cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{ Options: Options{ CleanupInterval: 100 * time.Millisecond, }, @@ -128,7 +137,7 @@ var _ = Describe("Prefetching expiration cache", func() { onPrefetchAfterPutChannel := make(chan int, 10) onPrefetchEntryReloaded := make(chan string, 10) onnPrefetchCacheHit := make(chan string, 10) - cache := NewPrefetchingCache[string](PrefetchingOptions[string]{ + cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{ Options: Options{ CleanupInterval: 100 * time.Millisecond, }, diff --git a/cmd/serve.go b/cmd/serve.go index 028229f3..b12e74d1 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/signal" @@ -43,7 +44,10 @@ func startServer(_ *cobra.Command, _ []string) error { signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) - srv, err := server.NewServer(cfg) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + + srv, err := server.NewServer(ctx, cfg) if err != nil { return fmt.Errorf("can't start server: %w", err) } @@ -51,7 +55,7 @@ func startServer(_ *cobra.Command, _ []string) error { const errChanSize = 10 errChan := make(chan error, errChanSize) - srv.Start(errChan) + srv.Start(ctx, errChan) var terminationErr error diff --git a/cmd/serve_test.go b/cmd/serve_test.go index 5a90c598..a1594927 100644 --- a/cmd/serve_test.go +++ b/cmd/serve_test.go @@ -93,7 +93,7 @@ var _ = Describe("Serve command", func() { By("terminate with signal", func() { var startError error - Eventually(errChan).Should(Receive(&startError)) + Eventually(errChan, "10s").Should(Receive(&startError)) Expect(startError).ShouldNot(BeNil()) Expect(startError.Error()).Should(ContainSubstring("address already in use")) }) diff --git a/config/config.go b/config/config.go index 44853b40..1f181c38 100644 --- a/config/config.go +++ b/config/config.go @@ -314,7 +314,10 @@ func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) { log.WithIndent(logger, " ", c.Downloads.LogConfig) } -func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context) error, logErr func(error)) error { +func (c *SourceLoadingConfig) StartPeriodicRefresh(ctx context.Context, + refresh func(context.Context) error, + logErr func(error), +) error { refreshAndRecover := func(ctx context.Context) (rerr error) { defer func() { if val := recover(); val != nil { @@ -331,20 +334,29 @@ func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context) } if c.RefreshPeriod > 0 { - go c.periodically(refreshAndRecover, logErr) + go c.periodically(ctx, refreshAndRecover, logErr) } return nil } -func (c *SourceLoadingConfig) periodically(refresh func(context.Context) error, logErr func(error)) { +func (c *SourceLoadingConfig) periodically(ctx context.Context, + refresh func(context.Context) error, + logErr func(error), +) { ticker := time.NewTicker(c.RefreshPeriod.ToDuration()) defer ticker.Stop() - for range ticker.C { - err := refresh(context.Background()) - if err != nil { - logErr(err) + for { + select { + case <-ticker.C: + err := refresh(ctx) + if err != nil { + logErr(err) + } + + case <-ctx.Done(): + return } } } diff --git a/config/config_test.go b/config/config_test.go index 7e273b56..e59c2076 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -688,6 +688,14 @@ bootstrapDns: }) Describe("SourceLoadingConfig", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + }) It("handles panics", func() { sut := SourceLoadingConfig{ Strategy: StartStrategyTypeFailOnError, @@ -695,7 +703,7 @@ bootstrapDns: panicMsg := "panic value" - err := sut.StartPeriodicRefresh(func(context.Context) error { + err := sut.StartPeriodicRefresh(ctx, func(context.Context) error { panic(panicMsg) }, func(err error) { Expect(err).Should(MatchError(ContainSubstring(panicMsg))) @@ -715,7 +723,7 @@ bootstrapDns: var call atomic.Int32 - err := sut.StartPeriodicRefresh(func(context.Context) error { + err := sut.StartPeriodicRefresh(ctx, func(context.Context) error { call := call.Add(1) calls <- call diff --git a/lists/list_cache.go b/lists/list_cache.go index 8d619ba5..f731b015 100644 --- a/lists/list_cache.go +++ b/lists/list_cache.go @@ -56,7 +56,7 @@ func (b *ListCache) LogConfig(logger *logrus.Entry) { } // NewListCache creates new list instance -func NewListCache( +func NewListCache(ctx context.Context, t ListCacheType, cfg config.SourceLoadingConfig, groupSources map[string][]config.BytesSource, downloader FileDownloader, ) (*ListCache, error) { @@ -72,7 +72,7 @@ func NewListCache( downloader: downloader, } - err := cfg.StartPeriodicRefresh(c.refresh, func(err error) { + err := cfg.StartPeriodicRefresh(ctx, c.refresh, func(err error) { logger().WithError(err).Errorf("could not init %s", t) }) if err != nil { diff --git a/lists/list_cache_benchmark_test.go b/lists/list_cache_benchmark_test.go index 8a82c9a1..1a710250 100644 --- a/lists/list_cache_benchmark_test.go +++ b/lists/list_cache_benchmark_test.go @@ -1,6 +1,7 @@ package lists import ( + "context" "testing" "github.com/0xERR0R/blocky/config" @@ -19,7 +20,7 @@ func BenchmarkRefresh(b *testing.B) { RefreshPeriod: config.Duration(-1), } downloader := NewDownloader(config.DownloaderConfig{}, nil) - cache, _ := NewListCache(ListCacheTypeBlacklist, cfg, lists, downloader) + cache, _ := NewListCache(context.Background(), ListCacheTypeBlacklist, cfg, lists, downloader) b.ReportAllocs() diff --git a/lists/list_cache_test.go b/lists/list_cache_test.go index c5426b47..acbda197 100644 --- a/lists/list_cache_test.go +++ b/lists/list_cache_test.go @@ -36,11 +36,16 @@ var _ = Describe("ListCache", func() { lists map[string][]config.BytesSource downloader FileDownloader mockDownloader *MockDownloader + ctx context.Context + cancelFn context.CancelFunc ) BeforeEach(func() { var err error + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + listCacheType = ListCacheTypeBlacklist sutConfig, err = config.WithDefaults[config.SourceLoadingConfig]() @@ -84,7 +89,7 @@ var _ = Describe("ListCache", func() { downloader = mockDownloader } - sut, err = NewListCache(listCacheType, sutConfig, lists, downloader) + sut, err = NewListCache(ctx, listCacheType, sutConfig, lists, downloader) Expect(err).Should(Succeed()) }) @@ -304,7 +309,7 @@ var _ = Describe("ListCache", func() { "gr1": config.NewBytesSources(file1, file2, file3), } - sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) + sut, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).Should(Succeed()) Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(lines1 + lines2 + lines3)) @@ -359,7 +364,7 @@ var _ = Describe("ListCache", func() { }, } - _, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) + _, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).ShouldNot(Succeed()) Expect(err).Should(MatchError(parsers.ErrTooManyErrors)) }) @@ -408,7 +413,7 @@ var _ = Describe("ListCache", func() { "gr2": {config.TextBytesSource("inline", "definition")}, } - sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) + sut, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).Should(Succeed()) sut.LogConfig(logger) @@ -430,7 +435,7 @@ var _ = Describe("ListCache", func() { "gr1": config.NewBytesSources("doesnotexist"), } - _, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader) + _, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader) Expect(err).Should(Succeed()) }) }) diff --git a/querylog/database_writer.go b/querylog/database_writer.go index 30dbcf57..f7856e7e 100644 --- a/querylog/database_writer.go +++ b/querylog/database_writer.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "fmt" "reflect" "strings" @@ -43,20 +44,20 @@ type DatabaseWriter struct { dbFlushPeriod time.Duration } -func NewDatabaseWriter(dbType, target string, logRetentionDays uint64, +func NewDatabaseWriter(ctx context.Context, dbType, target string, logRetentionDays uint64, dbFlushPeriod time.Duration, ) (*DatabaseWriter, error) { switch dbType { case "mysql": - return newDatabaseWriter(mysql.Open(target), logRetentionDays, dbFlushPeriod) + return newDatabaseWriter(ctx, mysql.Open(target), logRetentionDays, dbFlushPeriod) case "postgresql": - return newDatabaseWriter(postgres.Open(target), logRetentionDays, dbFlushPeriod) + return newDatabaseWriter(ctx, postgres.Open(target), logRetentionDays, dbFlushPeriod) } return nil, fmt.Errorf("incorrect database type provided: %s", dbType) } -func newDatabaseWriter(target gorm.Dialector, logRetentionDays uint64, +func newDatabaseWriter(ctx context.Context, target gorm.Dialector, logRetentionDays uint64, dbFlushPeriod time.Duration, ) (*DatabaseWriter, error) { db, err := gorm.Open(target, &gorm.Config{ @@ -84,7 +85,7 @@ func newDatabaseWriter(target gorm.Dialector, logRetentionDays uint64, dbFlushPeriod: dbFlushPeriod, } - go w.periodicFlush() + go w.periodicFlush(ctx) return w, nil } @@ -118,16 +119,20 @@ func databaseMigration(db *gorm.DB) error { return nil } -func (d *DatabaseWriter) periodicFlush() { +func (d *DatabaseWriter) periodicFlush(ctx context.Context) { ticker := time.NewTicker(d.dbFlushPeriod) defer ticker.Stop() for { - <-ticker.C + select { + case <-ticker.C: + err := d.doDBWrite() - err := d.doDBWrite() + util.LogOnError("can't write entries to the database: ", err) - util.LogOnError("can't write entries to the database: ", err) + case <-ctx.Done(): + return + } } } diff --git a/querylog/database_writer_test.go b/querylog/database_writer_test.go index df7618a3..6d2acbc5 100644 --- a/querylog/database_writer_test.go +++ b/querylog/database_writer_test.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "database/sql" "fmt" "time" @@ -20,19 +21,26 @@ import ( var err error var _ = Describe("DatabaseWriter", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + }) Describe("Database query log to sqlite", func() { var ( sqliteDB gorm.Dialector writer *DatabaseWriter ) - BeforeEach(func() { sqliteDB = sqlite.Open("file::memory:") }) When("New log entry was created", func() { BeforeEach(func() { - writer, err = newDatabaseWriter(sqliteDB, 7, time.Millisecond) + writer, err = newDatabaseWriter(ctx, sqliteDB, 7, time.Millisecond) Expect(err).Should(Succeed()) db, err := writer.db.DB() @@ -83,7 +91,7 @@ var _ = Describe("DatabaseWriter", func() { When("> 10000 Entries were created", func() { BeforeEach(func() { - writer, err = newDatabaseWriter(sqliteDB, 7, time.Millisecond) + writer, err = newDatabaseWriter(ctx, sqliteDB, 7, time.Millisecond) Expect(err).Should(Succeed()) }) @@ -114,7 +122,7 @@ var _ = Describe("DatabaseWriter", func() { When("There are log entries with timestamp exceeding the retention period", func() { BeforeEach(func() { - writer, err = newDatabaseWriter(sqliteDB, 1, time.Millisecond) + writer, err = newDatabaseWriter(ctx, sqliteDB, 1, time.Millisecond) Expect(err).Should(Succeed()) }) @@ -162,7 +170,7 @@ var _ = Describe("DatabaseWriter", func() { Describe("Database query log fails", func() { When("mysql connection parameters wrong", func() { It("should be log with fatal", func() { - _, err := NewDatabaseWriter("mysql", "wrong param", 7, 1) + _, err := NewDatabaseWriter(ctx, "mysql", "wrong param", 7, 1) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(HavePrefix("can't create database connection")) }) @@ -170,7 +178,7 @@ var _ = Describe("DatabaseWriter", func() { When("postgresql connection parameters wrong", func() { It("should be log with fatal", func() { - _, err := NewDatabaseWriter("postgresql", "wrong param", 7, 1) + _, err := NewDatabaseWriter(ctx, "postgresql", "wrong param", 7, 1) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(HavePrefix("can't create database connection")) }) @@ -178,7 +186,7 @@ var _ = Describe("DatabaseWriter", func() { When("invalid database type is specified", func() { It("should be log with fatal", func() { - _, err := NewDatabaseWriter("invalidsql", "", 7, 1) + _, err := NewDatabaseWriter(ctx, "invalidsql", "", 7, 1) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(HavePrefix("incorrect database type provided")) }) @@ -225,7 +233,7 @@ var _ = Describe("DatabaseWriter", func() { mock.ExpectExec(`ALTER TABLE log_entries ADD column if not exists id serial primary key`).WillReturnResult(sqlmock.NewResult(0, 0)) }) - _, err = newDatabaseWriter(dlc, 1, time.Millisecond) + _, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond) Expect(err).Should(Succeed()) }) }) @@ -257,7 +265,7 @@ var _ = Describe("DatabaseWriter", func() { mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnResult(sqlmock.NewResult(0, 0)) }) - _, err = newDatabaseWriter(dlc, 1, time.Millisecond) + _, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond) Expect(err).Should(Succeed()) }) }) @@ -273,7 +281,7 @@ var _ = Describe("DatabaseWriter", func() { mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnError(fmt.Errorf("error 1060: duplicate column name")) }) - _, err = newDatabaseWriter(dlc, 1, time.Millisecond) + _, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond) Expect(err).Should(Succeed()) }) @@ -286,7 +294,7 @@ var _ = Describe("DatabaseWriter", func() { mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnError(fmt.Errorf("error XXX: some index error")) }) - _, err = newDatabaseWriter(dlc, 1, time.Millisecond) + _, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("can't perform auto migration: error XXX: some index error")) }) @@ -298,7 +306,7 @@ var _ = Describe("DatabaseWriter", func() { mock.ExpectExec("CREATE TABLE `log_entries`").WillReturnError(fmt.Errorf("error XXX: some db error")) }) - _, err = newDatabaseWriter(dlc, 1, time.Millisecond) + _, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("can't perform auto migration: error XXX: some db error")) }) diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 22e06d00..2a6a33fb 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "fmt" "net" "sort" @@ -110,8 +111,10 @@ func clientGroupsBlock(cfg config.BlockingConfig) map[string][]string { } // NewBlockingResolver returns a new configured instance of the resolver -func NewBlockingResolver( - cfg config.BlockingConfig, redis *redis.Client, bootstrap *Bootstrap, +func NewBlockingResolver(ctx context.Context, + cfg config.BlockingConfig, + redis *redis.Client, + bootstrap *Bootstrap, ) (r *BlockingResolver, err error) { blockHandler, err := createBlockHandler(cfg) if err != nil { @@ -120,8 +123,10 @@ func NewBlockingResolver( downloader := lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()) - blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.Loading, cfg.BlackLists, downloader) - whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.Loading, cfg.WhiteLists, downloader) + blacklistMatcher, blErr := lists.NewListCache(ctx, lists.ListCacheTypeBlacklist, + cfg.Loading, cfg.BlackLists, downloader) + whitelistMatcher, wlErr := lists.NewListCache(ctx, lists.ListCacheTypeWhitelist, + cfg.Loading, cfg.WhiteLists, downloader) whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg) err = multierror.Append(err, blErr, wlErr).ErrorOrNil() @@ -145,14 +150,14 @@ func NewBlockingResolver( redisClient: redis, } - res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{ + res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{ CleanupInterval: defaultBlockingCleanUpInterval, }, func(key string) (val *[]net.IP, ttl time.Duration) { return res.queryForFQIdentifierIPs(key) }) if res.redisClient != nil { - setupRedisEnabledSubscriber(res) + setupRedisEnabledSubscriber(ctx, res) } err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) { @@ -166,20 +171,26 @@ func NewBlockingResolver( return res, nil } -func setupRedisEnabledSubscriber(c *BlockingResolver) { +func setupRedisEnabledSubscriber(ctx context.Context, c *BlockingResolver) { go func() { - for em := range c.redisClient.EnabledChannel { - if em != nil { - c.log().Debug("Received state from redis: ", em) + for { + select { + case em := <-c.redisClient.EnabledChannel: + if em != nil { + c.log().Debug("Received state from redis: ", em) - if em.State { - c.internalEnableBlocking() - } else { - err := c.internalDisableBlocking(em.Duration, em.Groups) - if err != nil { - c.log().Warn("Blocking couldn't be disabled:", err) + if em.State { + c.internalEnableBlocking() + } else { + err := c.internalDisableBlocking(em.Duration, em.Groups) + if err != nil { + c.log().Warn("Blocking couldn't be disabled:", err) + } } } + + case <-ctx.Done(): + return } } }() diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index 899fc972..b7004fbe 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "time" "github.com/0xERR0R/blocky/config" @@ -50,6 +51,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { sutConfig config.BlockingConfig m *mockResolver mockAnswer *dns.Msg + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -59,6 +62,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutConfig = config.BlockingConfig{ BlockType: "ZEROIP", BlockTTL: config.Duration(time.Minute), @@ -72,7 +78,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil) - sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap) + + sut, err = NewBlockingResolver(ctx, sutConfig, nil, systemResolverBootstrap) Expect(err).Should(Succeed()) sut.Next(m) }) @@ -113,7 +120,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { Expect(err).Should(Succeed()) // recreate to trigger a reload - sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap) + sut, err = NewBlockingResolver(ctx, sutConfig, nil, systemResolverBootstrap) Expect(err).Should(Succeed()) Eventually(groupCnt, "1s").Should(HaveLen(2)) @@ -1111,7 +1118,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { Describe("Create resolver with wrong parameter", func() { When("Wrong blockType is used", func() { It("should return error", func() { - _, err := NewBlockingResolver(config.BlockingConfig{ + _, err := NewBlockingResolver(ctx, config.BlockingConfig{ BlockType: "wrong", }, nil, systemResolverBootstrap) @@ -1121,7 +1128,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("strategy is failOnError", func() { It("should fail if lists can't be downloaded", func() { - _, err := NewBlockingResolver(config.BlockingConfig{ + _, err := NewBlockingResolver(ctx, config.BlockingConfig{ BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")}, WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")}, Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError}, @@ -1155,7 +1162,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BlockTTL: config.Duration(time.Minute), } - sut, err = NewBlockingResolver(sutConfig, redisClient, systemResolverBootstrap) + sut, err = NewBlockingResolver(ctx, sutConfig, redisClient, systemResolverBootstrap) Expect(err).Should(Succeed()) }) JustAfterEach(func() { diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index aed6f7a2..1d9a1c64 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -42,7 +42,7 @@ type Bootstrap struct { // NewBootstrap creates and returns a new Bootstrap. // Internally, it uses a CachingResolver and an UpstreamResolver. -func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) { +func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err error) { logger := log.PrefixedLog("bootstrap") timeout := defaultTimeout @@ -97,7 +97,8 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) { b.resolver = Chain( NewFilteringResolver(cfg.Filtering), - newCachingResolver(cachingCfg, nil, false), // false: no metrics, to not overwrite the main blocking resolver ones + // false: no metrics, to not overwrite the main blocking resolver ones + newCachingResolver(ctx, cachingCfg, nil, false), parallelResolver, ) diff --git a/resolver/bootstrap_test.go b/resolver/bootstrap_test.go index f71adddb..b1e1477a 100644 --- a/resolver/bootstrap_test.go +++ b/resolver/bootstrap_test.go @@ -25,6 +25,8 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { var ( sut *Bootstrap sutConfig *config.Config + ctx context.Context + cancelFn context.CancelFunc err error ) @@ -41,10 +43,13 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }, }, } + + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) }) JustBeforeEach(func() { - sut, err = NewBootstrap(sutConfig) + sut, err = NewBootstrap(ctx, sutConfig) Expect(err).Should(Succeed()) }) @@ -98,7 +103,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }, } - _, err := NewBootstrap(&cfg) + _, err := NewBootstrap(ctx, &cfg) Expect(err).ShouldNot(Succeed()) }) }) @@ -140,7 +145,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }, } - _, err := NewBootstrap(&cfg) + _, err := NewBootstrap(ctx, &cfg) Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring("must use IP instead of hostname")) }) @@ -184,7 +189,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }, } - _, err := NewBootstrap(&cfg) + _, err := NewBootstrap(ctx, &cfg) Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring("no IPs configured")) }) @@ -203,7 +208,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }, } - _, err := NewBootstrap(&cfg) + _, err := NewBootstrap(ctx, &cfg) Expect(err).Should(Succeed()) }) }) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 427d9964..00397898 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "fmt" "math" "sync/atomic" @@ -35,11 +36,18 @@ type CachingResolver struct { } // NewCachingResolver creates a new resolver instance -func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver { - return newCachingResolver(cfg, redis, true) +func NewCachingResolver(ctx context.Context, + cfg config.CachingConfig, + redis *redis.Client, +) *CachingResolver { + return newCachingResolver(ctx, cfg, redis, true) } -func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetricEvents bool) *CachingResolver { +func newCachingResolver(ctx context.Context, + cfg config.CachingConfig, + redis *redis.Client, + emitMetricEvents bool, +) *CachingResolver { c := &CachingResolver{ configurable: withConfig(&cfg), typed: withType("caching"), @@ -48,17 +56,17 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri emitMetricEvents: emitMetricEvents, } - configureCaches(c, &cfg) + configureCaches(ctx, c, &cfg) if c.redisClient != nil { - setupRedisCacheSubscriber(c) + setupRedisCacheSubscriber(ctx, c) c.redisClient.GetRedisCache() } return c } -func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { +func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.CachingConfig) { options := expirationcache.Options{ CleanupInterval: defaultCachingCleanUpInterval, MaxSize: uint(cfg.MaxItemsCount), @@ -91,9 +99,9 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { }, } - c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions) + c.resultCache = expirationcache.NewPrefetchingCache(ctx, prefetchingOptions) } else { - c.resultCache = expirationcache.NewCache[dns.Msg](options) + c.resultCache = expirationcache.NewCache[dns.Msg](ctx, options) } } @@ -117,13 +125,19 @@ func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Dura return nil, 0 } -func setupRedisCacheSubscriber(c *CachingResolver) { +func setupRedisCacheSubscriber(ctx context.Context, c *CachingResolver) { go func() { - for rc := range c.redisClient.CacheChannel { - if rc != nil { - c.log().Debug("Received key from redis: ", rc.Key) - ttl := c.adjustTTLs(rc.Response.Res.Answer) - c.putInCache(rc.Key, rc.Response, ttl, false) + for { + select { + case rc := <-c.redisClient.CacheChannel: + if rc != nil { + c.log().Debug("Received key from redis: ", rc.Key) + ttl := c.adjustTTLs(rc.Response.Res.Answer) + c.putInCache(rc.Key, rc.Response, ttl, false) + } + + case <-ctx.Done(): + return } } }() diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index 63a7333c..acc7be64 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "fmt" "time" @@ -26,6 +27,8 @@ var _ = Describe("CachingResolver", func() { sutConfig config.CachingConfig m *mockResolver mockAnswer *dns.Msg + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -43,7 +46,10 @@ var _ = Describe("CachingResolver", func() { }) JustBeforeEach(func() { - sut = NewCachingResolver(sutConfig, nil) + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + + sut = NewCachingResolver(ctx, sutConfig, nil) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil) sut.Next(m) @@ -79,7 +85,7 @@ var _ = Describe("CachingResolver", func() { It("should prefetch domain if query count > threshold", func() { // prepare resolver, set smaller caching times for testing prefetchThreshold := 5 - configureCaches(sut, &sutConfig) + configureCaches(ctx, sut, &sutConfig) domainPrefetched := make(chan bool, 1) prefetchHitDomain := make(chan bool, 1) @@ -725,7 +731,7 @@ var _ = Describe("CachingResolver", func() { } mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1000, A, "1.1.1.1") - sut = NewCachingResolver(sutConfig, redisClient) + sut = NewCachingResolver(ctx, sutConfig, redisClient) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil) sut.Next(m) diff --git a/resolver/client_names_resolver.go b/resolver/client_names_resolver.go index 4dee1066..a369130f 100644 --- a/resolver/client_names_resolver.go +++ b/resolver/client_names_resolver.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "net" "strings" "time" @@ -26,7 +27,7 @@ type ClientNamesResolver struct { } // NewClientNamesResolver creates new resolver instance -func NewClientNamesResolver( +func NewClientNamesResolver(ctx context.Context, cfg config.ClientLookupConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, ) (cr *ClientNamesResolver, err error) { var r Resolver @@ -41,7 +42,7 @@ func NewClientNamesResolver( configurable: withConfig(&cfg), typed: withType("client_names"), - cache: expirationcache.NewCache[[]string](expirationcache.Options{ + cache: expirationcache.NewCache[[]string](ctx, expirationcache.Options{ CleanupInterval: time.Hour, }), externalResolver: r, diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go index 75374f59..dfe03fb0 100644 --- a/resolver/client_names_resolver_test.go +++ b/resolver/client_names_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "errors" "net" @@ -21,6 +22,8 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { sut *ClientNamesResolver sutConfig config.ClientLookupConfig m *mockResolver + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -31,8 +34,10 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { JustBeforeEach(func() { var err error + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) - sut, err = NewClientNamesResolver(sutConfig, nil, false) + sut, err = NewClientNamesResolver(ctx, sutConfig, nil, false) Expect(err).Should(Succeed()) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) @@ -361,9 +366,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { Describe("Connstruction", func() { When("upstream is invalid", func() { It("errors during construction", func() { - b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) + b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewClientNamesResolver(config.ClientLookupConfig{ + r, err := NewClientNamesResolver(ctx, config.ClientLookupConfig{ Upstream: config.Upstream{Host: "example.com"}, }, b, true) diff --git a/resolver/conditional_upstream_resolver_test.go b/resolver/conditional_upstream_resolver_test.go index 38f40fa4..72f0cb2a 100644 --- a/resolver/conditional_upstream_resolver_test.go +++ b/resolver/conditional_upstream_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -184,7 +186,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu When("upstream is invalid", func() { It("errors during construction", func() { - b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) + ctx, cancelFn := context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{ Mapping: config.ConditionalUpstreamMapping{ diff --git a/resolver/hosts_file_resolver.go b/resolver/hosts_file_resolver.go index be4b0eb7..3da3082d 100644 --- a/resolver/hosts_file_resolver.go +++ b/resolver/hosts_file_resolver.go @@ -34,7 +34,10 @@ type HostsFileResolver struct { downloader lists.FileDownloader } -func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*HostsFileResolver, error) { +func NewHostsFileResolver(ctx context.Context, + cfg config.HostsFileConfig, + bootstrap *Bootstrap, +) (*HostsFileResolver, error) { r := HostsFileResolver{ configurable: withConfig(&cfg), typed: withType("hosts_file"), @@ -42,7 +45,7 @@ func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*Ho downloader: lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()), } - err := cfg.Loading.StartPeriodicRefresh(r.loadSources, func(err error) { + err := cfg.Loading.StartPeriodicRefresh(ctx, r.loadSources, func(err error) { r.log().WithError(err).Errorf("could not load hosts files") }) if err != nil { diff --git a/resolver/hosts_file_resolver_test.go b/resolver/hosts_file_resolver_test.go index eef4591d..69341311 100644 --- a/resolver/hosts_file_resolver_test.go +++ b/resolver/hosts_file_resolver_test.go @@ -53,7 +53,9 @@ var _ = Describe("HostsFileResolver", func() { JustBeforeEach(func() { var err error - sut, err = NewHostsFileResolver(sutConfig, systemResolverBootstrap) + ctx, cancelFn := context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap) Expect(err).Should(Succeed()) m = &mockResolver{} diff --git a/resolver/mocks_test.go b/resolver/mocks_test.go index 6769aafa..48fa2586 100644 --- a/resolver/mocks_test.go +++ b/resolver/mocks_test.go @@ -118,10 +118,10 @@ func autoAnswer(qType dns.Type, qName string) (*dns.Msg, error) { } // newTestBootstrap creates a test Bootstrap -func newTestBootstrap(response *dns.Msg) *Bootstrap { +func newTestBootstrap(ctx context.Context, response *dns.Msg) *Bootstrap { bootstrapUpstream := &mockResolver{} - b, err := NewBootstrap(&config.Config{}) + b, err := NewBootstrap(ctx, &config.Config{}) util.FatalOnError("can't create bootstrap", err) b.resolver = bootstrapUpstream diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index 0278be84..3c25b69c 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "strings" "time" @@ -24,6 +25,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { sut *ParallelBestResolver sutMapping config.UpstreamGroups sutVerify bool + ctx context.Context + cancelFn context.CancelFunc err error @@ -37,6 +40,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutMapping = config.UpstreamGroups{ upstreamDefaultCfgName: { config.Upstream{ @@ -111,7 +117,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { When("no upstream resolvers can be reached", func() { BeforeEach(func() { - bootstrap = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) + bootstrap = newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) sutMapping = config.UpstreamGroups{ upstreamDefaultCfgName: { @@ -328,7 +334,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { When("upstream is invalid", func() { It("errors during construction", func() { - b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) + b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) r, err := NewParallelBestResolver(config.UpstreamsConfig{ Groups: config.UpstreamGroups{"test": {{Host: "example.com"}}}, diff --git a/resolver/query_logging_resolver.go b/resolver/query_logging_resolver.go index 6edd6abb..170101eb 100644 --- a/resolver/query_logging_resolver.go +++ b/resolver/query_logging_resolver.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "time" "github.com/0xERR0R/blocky/config" @@ -29,7 +30,7 @@ type QueryLoggingResolver struct { } // NewQueryLoggingResolver returns a new resolver instance -func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver { +func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLogConfig) *QueryLoggingResolver { logger := log.PrefixedLog(queryLoggingResolverType) var writer querylog.Writer @@ -43,10 +44,10 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver { case config.QueryLogTypeCsvClient: writer, err = querylog.NewCSVWriter(cfg.Target, true, cfg.LogRetentionDays) case config.QueryLogTypeMysql: - writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays, + writer, err = querylog.NewDatabaseWriter(ctx, "mysql", cfg.Target, cfg.LogRetentionDays, cfg.FlushInterval.ToDuration()) case config.QueryLogTypePostgresql: - writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays, + writer, err = querylog.NewDatabaseWriter(ctx, "postgresql", cfg.Target, cfg.LogRetentionDays, cfg.FlushInterval.ToDuration()) case config.QueryLogTypeConsole: writer = querylog.NewLoggerWriter() @@ -81,23 +82,27 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver { writer: writer, } - go resolver.writeLog() + go resolver.writeLog(ctx) if cfg.LogRetentionDays > 0 { - go resolver.periodicCleanUp() + go resolver.periodicCleanUp(ctx) } return &resolver } // triggers periodically cleanup of old log files -func (r *QueryLoggingResolver) periodicCleanUp() { +func (r *QueryLoggingResolver) periodicCleanUp(ctx context.Context) { ticker := time.NewTicker(cleanUpRunPeriod) defer ticker.Stop() for { - <-ticker.C - r.doCleanUp() + select { + case <-ticker.C: + r.doCleanUp() + case <-ctx.Done(): + return + } } } @@ -164,18 +169,23 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response * } // write entry: if log directory is configured, write to log file -func (r *QueryLoggingResolver) writeLog() { - for logEntry := range r.logChan { - start := time.Now() +func (r *QueryLoggingResolver) writeLog(ctx context.Context) { + for { + select { + case logEntry := <-r.logChan: + start := time.Now() - r.writer.Write(logEntry) + r.writer.Write(logEntry) - halfCap := cap(r.logChan) / 2 //nolint:gomnd + halfCap := cap(r.logChan) / 2 //nolint:gomnd - // if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.) - if len(r.logChan) > halfCap { - r.log().WithField("channel_len", - len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds()) + // if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.) + if len(r.logChan) > halfCap { + r.log().WithField("channel_len", + len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds()) + } + case <-ctx.Done(): + return } } } diff --git a/resolver/query_logging_resolver_test.go b/resolver/query_logging_resolver_test.go index 7386654d..efc33bf5 100644 --- a/resolver/query_logging_resolver_test.go +++ b/resolver/query_logging_resolver_test.go @@ -2,6 +2,7 @@ package resolver import ( "bufio" + "context" "encoding/csv" "errors" "fmt" @@ -63,8 +64,10 @@ var _ = Describe("QueryLoggingResolver", func() { sutConfig.SetDefaults() // not called when using a struct literal } - sut = NewQueryLoggingResolver(sutConfig) - DeferCleanup(func() { close(sut.logChan) }) + ctx, cancelFn := context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + + sut = NewQueryLoggingResolver(ctx, sutConfig) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil) sut.Next(m) @@ -151,7 +154,7 @@ var _ = Describe("QueryLoggingResolver", func() { g.Expect(csvLines[0][7]).Should(Equal("NOERROR")) g.Expect(csvLines[0][8]).Should(Equal("RESOLVED")) g.Expect(csvLines[0][9]).Should(Equal("A")) - }, "1s").Should(Succeed()) + }).Should(Succeed()) }) By("check log for client2", func() { @@ -169,7 +172,7 @@ var _ = Describe("QueryLoggingResolver", func() { g.Expect(csvLines[0][7]).Should(Equal("NOERROR")) g.Expect(csvLines[0][8]).Should(Equal("RESOLVED")) g.Expect(csvLines[0][9]).Should(Equal("A")) - }, "1s").Should(Succeed()) + }).Should(Succeed()) }) }) }) diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index 05045ead..5e703a66 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "strings" "github.com/0xERR0R/blocky/config" @@ -109,16 +110,24 @@ var _ = Describe("Resolver", func() { }) Describe("Name", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + }) When("'Name' is called", func() { It("should return resolver name", func() { - br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap) + br, _ := NewBlockingResolver(ctx, config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap) name := Name(br) Expect(name).Should(Equal("blocking")) }) }) When("'Name' is called on a NamedResolver", func() { It("should return its custom name", func() { - br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap) + br, _ := NewBlockingResolver(ctx, config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap) cfg := config.RewriterConfig{Rewrite: map[string]string{"not": "empty"}} r := NewRewriterResolver(cfg, br) diff --git a/server/server.go b/server/server.go index 9561a160..e5311f82 100644 --- a/server/server.go +++ b/server/server.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -113,7 +114,7 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) { // NewServer creates new server instance with passed config // //nolint:funlen -func NewServer(cfg *config.Config) (server *Server, err error) { +func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { log.ConfigureLogger(&cfg.Log) var cert tls.Certificate @@ -145,7 +146,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) { metrics.RegisterEventListeners() - bootstrap, err := resolver.NewBootstrap(cfg) + bootstrap, err := resolver.NewBootstrap(ctx, cfg) if err != nil { return nil, err } @@ -155,7 +156,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) { return nil, redisErr } - queryResolver, queryError := createQueryResolver(cfg, bootstrap, redisClient) + queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient) if queryError != nil { return nil, queryError } @@ -385,6 +386,7 @@ func createSelfSignedCert() (tls.Certificate, error) { } func createQueryResolver( + ctx context.Context, cfg *config.Config, bootstrap *resolver.Bootstrap, redisClient *redis.Client, @@ -396,10 +398,10 @@ func createQueryResolver( upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches) - blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap) - clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream) + blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap) + clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream) condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream) - hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap) + hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap) err = multierror.Append( multierror.Prefix(utErr, "upstream tree resolver: "), @@ -417,12 +419,12 @@ func createQueryResolver( resolver.NewFqdnOnlyResolver(cfg.FqdnOnly), clientNames, resolver.NewEdeResolver(cfg.Ede), - resolver.NewQueryLoggingResolver(cfg.QueryLog), + resolver.NewQueryLoggingResolver(ctx, cfg.QueryLog), resolver.NewMetricsResolver(cfg.Prometheus), resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)), hostsFile, blocking, - resolver.NewCachingResolver(cfg.Caching, redisClient), + resolver.NewCachingResolver(ctx, cfg.Caching, redisClient), resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream), resolver.NewSpecialUseDomainNamesResolver(cfg.SUDN), upstreamTree, @@ -513,7 +515,7 @@ const ( ) // Start starts the server -func (s *Server) Start(errCh chan<- error) { +func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Info("Starting server") for _, srv := range s.dnsServers { @@ -572,7 +574,7 @@ func (s *Server) Start(errCh chan<- error) { }() } - registerPrintConfigurationTrigger(s) + registerPrintConfigurationTrigger(ctx, s) } // Stop stops the server diff --git a/server/server_config_trigger.go b/server/server_config_trigger.go index 4867ccea..8903c1ca 100644 --- a/server/server_config_trigger.go +++ b/server/server_config_trigger.go @@ -4,19 +4,25 @@ package server import ( + "context" "os" "os/signal" "syscall" ) -func registerPrintConfigurationTrigger(s *Server) { +func registerPrintConfigurationTrigger(ctx context.Context, s *Server) { signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGUSR1) go func() { for { - <-signals - s.printConfiguration() + select { + case <-signals: + s.printConfiguration() + + case <-ctx.Done(): + return + } } }() } diff --git a/server/server_test.go b/server/server_test.go index fa5275fd..a346655c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "encoding/base64" "io" "net" @@ -32,6 +33,8 @@ var ( var _ = BeforeSuite(func() { var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream + ctx, cancelFn := context.WithCancel(context.Background()) + DeferCleanup(cancelFn) googleMockUpstream := resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { if request.Question[0].Name == "error." { return nil @@ -102,7 +105,7 @@ var _ = BeforeSuite(func() { Expect(youtubeFile.Error).Should(Succeed()) // create server - sut, err = NewServer(&config.Config{ + sut, err = NewServer(ctx, &config.Config{ CustomDNS: config.CustomDNSConfig{ CustomTTL: config.Duration(3600 * time.Second), Mapping: config.CustomDNSMapping{ @@ -169,7 +172,7 @@ var _ = BeforeSuite(func() { // start server go func() { - sut.Start(errChan) + sut.Start(ctx, errChan) }() DeferCleanup(sut.Stop) @@ -177,6 +180,14 @@ var _ = BeforeSuite(func() { }) var _ = Describe("Running DNS server", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + }) Describe("performing DNS request with running server", func() { BeforeEach(func() { mockClientName.Store("") @@ -555,14 +566,14 @@ var _ = Describe("Running DNS server", func() { }) When("Server is created", func() { It("is created without redis connection", func() { - _, err := NewServer(&cfg) + _, err := NewServer(ctx, &cfg) Expect(err).Should(Succeed()) }) It("can't be created if redis server is unavailable", func() { cfg.Redis.Required = true - _, err := NewServer(&cfg) + _, err := NewServer(ctx, &cfg) Expect(err).ShouldNot(Succeed()) }) @@ -573,7 +584,7 @@ var _ = Describe("Running DNS server", func() { When("Server start is called", func() { It("start was called 2 times, start should fail", func() { // create server - server, err := NewServer(&config.Config{ + server, err := NewServer(ctx, &config.Config{ Upstreams: config.UpstreamsConfig{ Groups: map[string][]config.Upstream{ "default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}}, @@ -598,14 +609,14 @@ var _ = Describe("Running DNS server", func() { errChan := make(chan error, 10) // start server - go server.Start(errChan) + go server.Start(ctx, errChan) DeferCleanup(server.Stop) Consistently(errChan, "1s").ShouldNot(Receive()) // start again -> should fail - server.Start(errChan) + server.Start(ctx, errChan) Eventually(errChan).Should(Receive()) }) @@ -615,7 +626,7 @@ var _ = Describe("Running DNS server", func() { When("Stop is called", func() { It("stop was called 2 times, start should fail", func() { // create server - server, err := NewServer(&config.Config{ + server, err := NewServer(ctx, &config.Config{ Upstreams: config.UpstreamsConfig{ Groups: map[string][]config.Upstream{ "default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}}, @@ -641,7 +652,7 @@ var _ = Describe("Running DNS server", func() { // start server go func() { - server.Start(errChan) + server.Start(ctx, errChan) }() time.Sleep(100 * time.Millisecond) @@ -681,7 +692,7 @@ var _ = Describe("Running DNS server", func() { Describe("create query resolver", func() { When("some upstream returns error", func() { It("create query resolver should return error", func() { - r, err := createQueryResolver(&config.Config{ + r, err := createQueryResolver(ctx, &config.Config{ StartVerifyUpstream: true, Upstreams: config.UpstreamsConfig{ Groups: config.UpstreamGroups{ @@ -736,8 +747,7 @@ var _ = Describe("Running DNS server", func() { cfg.Ports = config.PortsConfig{ HTTPS: []string{":14443"}, } - - sut, err := NewServer(&cfg) + sut, err := NewServer(ctx, &cfg) Expect(err).Should(Succeed()) Expect(sut.cert.Certificate).ShouldNot(BeNil()) })