refactor: pass context for goroutine shutdown (#1187)

Dimitri Herzog 2023-10-07 22:21:40 +02:00 committed by GitHub
parent d9e91da686
commit 33ea933015
32 changed files with 349 additions and 177 deletions
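
Every file below applies the same refactoring: long-lived goroutines that used to block forever on a bare channel receive (a ticker, an OS signal, or a message channel) now take a context.Context and select on ctx.Done(), so cancelling one root context shuts them all down. A minimal, self-contained sketch of the recurring pattern (all names here are illustrative, not taken from the codebase):

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    // cleanupLoop runs work on every tick until ctx is cancelled.
    func cleanupLoop(ctx context.Context, interval time.Duration, work func()) {
        ticker := time.NewTicker(interval)
        defer ticker.Stop()

        for {
            select {
            case <-ticker.C:
                work()
            case <-ctx.Done():
                return // goroutine exits cleanly on shutdown
            }
        }
    }

    func main() {
        ctx, cancel := context.WithCancel(context.Background())

        go cleanupLoop(ctx, 10*time.Millisecond, func() { fmt.Println("tick") })

        time.Sleep(35 * time.Millisecond)
        cancel()                          // before this commit, such loops ran forever
        time.Sleep(10 * time.Millisecond) // give the goroutine time to observe ctx.Done()
    }

The test files pair this with context.WithCancel plus Ginkgo's DeferCleanup, so each spec tears down the goroutines it started.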

View File

@ -1,6 +1,7 @@
package expirationcache
import (
"context"
"time"
lru "github.com/hashicorp/golang-lru"
@ -47,11 +48,11 @@ type OnCacheMissCallback func(key string)
// OnAfterPutCallback will be called after put, receives new element count as parameter
type OnAfterPutCallback func(newSize int)
func NewCache[T any](options Options) *ExpiringLRUCache[T] {
return NewCacheWithOnExpired[T](options, nil)
func NewCache[T any](ctx context.Context, options Options) *ExpiringLRUCache[T] {
return NewCacheWithOnExpired[T](ctx, options, nil)
}
func NewCacheWithOnExpired[T any](options Options,
func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
onExpirationFn OnExpirationCallback[T],
) *ExpiringLRUCache[T] {
l, _ := lru.New(defaultSize)
@ -90,18 +91,22 @@ func NewCacheWithOnExpired[T any](options Options,
c.preExpirationFn = onExpirationFn
}
go periodicCleanup(c)
go periodicCleanup(ctx, c)
return c
}
func periodicCleanup[T any](c *ExpiringLRUCache[T]) {
func periodicCleanup[T any](ctx context.Context, c *ExpiringLRUCache[T]) {
ticker := time.NewTicker(c.cleanUpInterval)
defer ticker.Stop()
for {
<-ticker.C
select {
case <-ticker.C:
c.cleanUp()
case <-ctx.Done():
return
}
}
}
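
The tests in the next file drive this through Ginkgo, cancelling each spec's context via DeferCleanup. As a hedged illustration of what the change buys, a plain go test using go.uber.org/goleak (an assumed dependency; the project is not shown using it here) can assert that the cleanup goroutine actually terminates:

    package expirationcache

    import (
        "context"
        "testing"
        "time"

        "go.uber.org/goleak" // assumption: goleak is available for this sketch
    )

    func TestCleanupGoroutineStops(t *testing.T) {
        defer goleak.VerifyNone(t) // fails the test if any goroutine is still running

        ctx, cancel := context.WithCancel(context.Background())
        NewCache[string](ctx, Options{CleanupInterval: time.Millisecond})

        cancel()                         // request shutdown of periodicCleanup
        time.Sleep(5 * time.Millisecond) // allow the goroutine to observe ctx.Done()
    }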

View File

@ -1,6 +1,7 @@
package expirationcache
import (
"context"
"time"
. "github.com/onsi/ginkgo/v2"
@ -8,14 +9,22 @@ import (
)
var _ = Describe("Expiration cache", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
})
Describe("Basic operations", func() {
When("string cache was created", func() {
It("Initial cache should be empty", func() {
cache := NewCache[string](Options{})
cache := NewCache[string](ctx, Options{})
Expect(cache.TotalCount()).Should(Equal(0))
})
It("Initial cache should not contain any elements", func() {
cache := NewCache[string](Options{})
cache := NewCache[string](ctx, Options{})
val, expiration := cache.Get("key1")
Expect(val).Should(BeNil())
Expect(expiration).Should(Equal(time.Duration(0)))
@ -23,7 +32,7 @@ var _ = Describe("Expiration cache", func() {
})
When("Put new value with positive TTL", func() {
It("Should return the value before element expires", func() {
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
cache := NewCache[string](ctx, Options{CleanupInterval: 100 * time.Millisecond})
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
val, expiration := cache.Get("key1")
@ -33,7 +42,7 @@ var _ = Describe("Expiration cache", func() {
Expect(cache.TotalCount()).Should(Equal(1))
})
It("Should return nil after expiration", func() {
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
cache := NewCache[string](ctx, Options{CleanupInterval: 100 * time.Millisecond})
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
@ -47,12 +56,12 @@ var _ = Describe("Expiration cache", func() {
// wait for cleanup run
Eventually(func() int {
return cache.lru.Len()
}, "100ms").Should(Equal(0))
}).Should(Equal(0))
})
})
When("Put new value without expiration", func() {
It("Should not cache the value", func() {
cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond})
cache := NewCache[string](ctx, Options{CleanupInterval: 50 * time.Millisecond})
v := "x"
cache.Put("key1", &v, 0)
val, expiration := cache.Get("key1")
@ -63,7 +72,7 @@ var _ = Describe("Expiration cache", func() {
})
When("Put updated value", func() {
It("Should return updated value", func() {
cache := NewCache[string](Options{})
cache := NewCache[string](ctx, Options{})
v1 := "v1"
v2 := "v2"
cache.Put("key1", &v1, 50*time.Millisecond)
@ -79,7 +88,7 @@ var _ = Describe("Expiration cache", func() {
})
When("Purging after usage", func() {
It("Should be empty after purge", func() {
cache := NewCache[string](Options{})
cache := NewCache[string](ctx, Options{})
v1 := "y"
cache.Put("key1", &v1, time.Second)
@ -97,7 +106,7 @@ var _ = Describe("Expiration cache", func() {
onCacheHitChannel := make(chan string, 10)
onCacheMissChannel := make(chan string, 10)
onAfterPutChannel := make(chan int, 10)
cache := NewCache[string](Options{
cache := NewCache[string](ctx, Options{
OnCacheHitFn: func(key string) {
onCacheHitChannel <- key
},
@ -146,7 +155,7 @@ var _ = Describe("Expiration cache", func() {
return &v2, time.Second
}
cache := NewCacheWithOnExpired[string](Options{}, fn)
cache := NewCacheWithOnExpired[string](ctx, Options{}, fn)
v1 := "v1"
cache.Put("key1", &v1, 50*time.Millisecond)
@ -165,7 +174,7 @@ var _ = Describe("Expiration cache", func() {
return &v2, time.Second
}
cache := NewCacheWithOnExpired[string](Options{}, fn)
cache := NewCacheWithOnExpired[string](ctx, Options{}, fn)
v1 := "somval"
cache.Put("key1", &v1, time.Millisecond)
@ -186,7 +195,7 @@ var _ = Describe("Expiration cache", func() {
fn := func(key string) (val *string, ttl time.Duration) {
return nil, 0
}
cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn)
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
v1 := "z"
cache.Put("key1", &v1, 50*time.Millisecond)
@ -199,7 +208,7 @@ var _ = Describe("Expiration cache", func() {
Describe("LRU behaviour", func() {
When("Defined max size is reached", func() {
It("should remove old elements", func() {
cache := NewCache[string](Options{MaxSize: 3})
cache := NewCache[string](ctx, Options{MaxSize: 3})
v1 := "val1"
v2 := "val2"

View File

@ -1,6 +1,7 @@
package expirationcache
import (
"context"
"sync/atomic"
"time"
)
@ -39,9 +40,9 @@ type PrefetchingOptions[T any] struct {
type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]])
func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
func NewPrefetchingCache[T any](ctx context.Context, options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
pc := &PrefetchingExpiringLRUCache[T]{
prefetchingNameCache: NewCache[atomic.Uint32](Options{
prefetchingNameCache: NewCache[atomic.Uint32](ctx, Options{
CleanupInterval: time.Minute,
MaxSize: uint(options.PrefetchMaxItemsCount),
OnAfterPutFn: options.OnPrefetchAfterPut,
@ -53,7 +54,7 @@ func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpir
onPrefetchCacheHit: options.OnPrefetchCacheHit,
}
pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired)
pc.cache = NewCacheWithOnExpired[cacheValue[T]](ctx, options.Options, pc.onExpired)
return pc
}

View File

@ -1,6 +1,7 @@
package expirationcache
import (
"context"
"time"
. "github.com/onsi/ginkgo/v2"
@ -8,21 +9,29 @@ import (
)
var _ = Describe("Prefetching expiration cache", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
})
Describe("Basic operations", func() {
When("string cache was created", func() {
It("Initial cache should be empty", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{})
Expect(cache.TotalCount()).Should(Equal(0))
})
It("Initial cache should not contain any elements", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{})
val, expiration := cache.Get("key1")
Expect(val).Should(BeNil())
Expect(expiration).Should(Equal(time.Duration(0)))
})
It("Should work as cache (basic operations)", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{})
v := "v1"
cache.Put("key1", &v, 50*time.Millisecond)
@ -39,7 +48,7 @@ var _ = Describe("Prefetching expiration cache", func() {
})
Context("Prefetching", func() {
It("Should prefetch element", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
@ -71,7 +80,7 @@ var _ = Describe("Prefetching expiration cache", func() {
})
})
It("Should not prefetch element", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
@ -100,7 +109,7 @@ var _ = Describe("Prefetching expiration cache", func() {
})
})
It("With default config (threshold = 0) should always prefetch", func() {
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
@ -128,7 +137,7 @@ var _ = Describe("Prefetching expiration cache", func() {
onPrefetchAfterPutChannel := make(chan int, 10)
onPrefetchEntryReloaded := make(chan string, 10)
onPrefetchCacheHit := make(chan string, 10)
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},

View File

@ -1,6 +1,7 @@
package cmd
import (
"context"
"fmt"
"os"
"os/signal"
@ -43,7 +44,10 @@ func startServer(_ *cobra.Command, _ []string) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
srv, err := server.NewServer(cfg)
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()
srv, err := server.NewServer(ctx, cfg)
if err != nil {
return fmt.Errorf("can't start server: %w", err)
}
@ -51,7 +55,7 @@ func startServer(_ *cobra.Command, _ []string) error {
const errChanSize = 10
errChan := make(chan error, errChanSize)
srv.Start(errChan)
srv.Start(ctx, errChan)
var terminationErr error
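
The command keeps its explicit signal channel and derives the context separately. Since Go 1.16, os/signal.NotifyContext can fuse the two; a hedged alternative sketch (startWithSignals is hypothetical, reusing NewServer and Start from this diff), not what the commit does:

    // startWithSignals derives the root context directly from OS signals
    // instead of managing a separate signal channel.
    func startWithSignals(cfg *config.Config, errChan chan error) error {
        ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
        defer stop() // deregisters the signals and releases the context

        srv, err := server.NewServer(ctx, cfg)
        if err != nil {
            return fmt.Errorf("can't start server: %w", err)
        }

        srv.Start(ctx, errChan)
        <-ctx.Done() // block until SIGINT or SIGTERM cancels the context

        return nil
    }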

View File

@ -93,7 +93,7 @@ var _ = Describe("Serve command", func() {
By("terminate with signal", func() {
var startError error
Eventually(errChan).Should(Receive(&startError))
Eventually(errChan, "10s").Should(Receive(&startError))
Expect(startError).ShouldNot(BeNil())
Expect(startError.Error()).Should(ContainSubstring("address already in use"))
})

View File

@ -314,7 +314,10 @@ func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) {
log.WithIndent(logger, " ", c.Downloads.LogConfig)
}
func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context) error, logErr func(error)) error {
func (c *SourceLoadingConfig) StartPeriodicRefresh(ctx context.Context,
refresh func(context.Context) error,
logErr func(error),
) error {
refreshAndRecover := func(ctx context.Context) (rerr error) {
defer func() {
if val := recover(); val != nil {
@ -331,20 +334,29 @@ func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context)
}
if c.RefreshPeriod > 0 {
go c.periodically(refreshAndRecover, logErr)
go c.periodically(ctx, refreshAndRecover, logErr)
}
return nil
}
func (c *SourceLoadingConfig) periodically(refresh func(context.Context) error, logErr func(error)) {
func (c *SourceLoadingConfig) periodically(ctx context.Context,
refresh func(context.Context) error,
logErr func(error),
) {
ticker := time.NewTicker(c.RefreshPeriod.ToDuration())
defer ticker.Stop()
for range ticker.C {
err := refresh(context.Background())
for {
select {
case <-ticker.C:
err := refresh(ctx)
if err != nil {
logErr(err)
}
case <-ctx.Done():
return
}
}
}
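
Besides stopping the ticker loop on cancellation, periodically now passes the long-lived ctx to refresh instead of context.Background(), which lets an in-flight refresh be aborted mid-download. A hedged sketch of what a context-aware refresh step can look like (fetchList is hypothetical; the stdlib calls are real):

    // fetchList aborts the HTTP download as soon as ctx is cancelled,
    // because the request carries the context.
    func fetchList(ctx context.Context, url string) (*http.Response, error) {
        req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
        if err != nil {
            return nil, err
        }

        return http.DefaultClient.Do(req)
    }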

View File

@ -688,6 +688,14 @@ bootstrapDns:
})
Describe("SourceLoadingConfig", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
})
It("handles panics", func() {
sut := SourceLoadingConfig{
Strategy: StartStrategyTypeFailOnError,
@ -695,7 +703,7 @@ bootstrapDns:
panicMsg := "panic value"
err := sut.StartPeriodicRefresh(func(context.Context) error {
err := sut.StartPeriodicRefresh(ctx, func(context.Context) error {
panic(panicMsg)
}, func(err error) {
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
@ -715,7 +723,7 @@ bootstrapDns:
var call atomic.Int32
err := sut.StartPeriodicRefresh(func(context.Context) error {
err := sut.StartPeriodicRefresh(ctx, func(context.Context) error {
call := call.Add(1)
calls <- call

View File

@ -56,7 +56,7 @@ func (b *ListCache) LogConfig(logger *logrus.Entry) {
}
// NewListCache creates new list instance
func NewListCache(
func NewListCache(ctx context.Context,
t ListCacheType, cfg config.SourceLoadingConfig,
groupSources map[string][]config.BytesSource, downloader FileDownloader,
) (*ListCache, error) {
@ -72,7 +72,7 @@ func NewListCache(
downloader: downloader,
}
err := cfg.StartPeriodicRefresh(c.refresh, func(err error) {
err := cfg.StartPeriodicRefresh(ctx, c.refresh, func(err error) {
logger().WithError(err).Errorf("could not init %s", t)
})
if err != nil {

View File

@ -1,6 +1,7 @@
package lists
import (
"context"
"testing"
"github.com/0xERR0R/blocky/config"
@ -19,7 +20,7 @@ func BenchmarkRefresh(b *testing.B) {
RefreshPeriod: config.Duration(-1),
}
downloader := NewDownloader(config.DownloaderConfig{}, nil)
cache, _ := NewListCache(ListCacheTypeBlacklist, cfg, lists, downloader)
cache, _ := NewListCache(context.Background(), ListCacheTypeBlacklist, cfg, lists, downloader)
b.ReportAllocs()

View File

@ -36,11 +36,16 @@ var _ = Describe("ListCache", func() {
lists map[string][]config.BytesSource
downloader FileDownloader
mockDownloader *MockDownloader
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
var err error
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
listCacheType = ListCacheTypeBlacklist
sutConfig, err = config.WithDefaults[config.SourceLoadingConfig]()
@ -84,7 +89,7 @@ var _ = Describe("ListCache", func() {
downloader = mockDownloader
}
sut, err = NewListCache(listCacheType, sutConfig, lists, downloader)
sut, err = NewListCache(ctx, listCacheType, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
})
@ -304,7 +309,7 @@ var _ = Describe("ListCache", func() {
"gr1": config.NewBytesSources(file1, file2, file3),
}
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
sut, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(lines1 + lines2 + lines3))
@ -359,7 +364,7 @@ var _ = Describe("ListCache", func() {
},
}
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
_, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(parsers.ErrTooManyErrors))
})
@ -408,7 +413,7 @@ var _ = Describe("ListCache", func() {
"gr2": {config.TextBytesSource("inline", "definition")},
}
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
sut, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
sut.LogConfig(logger)
@ -430,7 +435,7 @@ var _ = Describe("ListCache", func() {
"gr1": config.NewBytesSources("doesnotexist"),
}
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
_, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
Expect(err).Should(Succeed())
})
})

View File

@ -1,6 +1,7 @@
package querylog
import (
"context"
"fmt"
"reflect"
"strings"
@ -43,20 +44,20 @@ type DatabaseWriter struct {
dbFlushPeriod time.Duration
}
func NewDatabaseWriter(dbType, target string, logRetentionDays uint64,
func NewDatabaseWriter(ctx context.Context, dbType, target string, logRetentionDays uint64,
dbFlushPeriod time.Duration,
) (*DatabaseWriter, error) {
switch dbType {
case "mysql":
return newDatabaseWriter(mysql.Open(target), logRetentionDays, dbFlushPeriod)
return newDatabaseWriter(ctx, mysql.Open(target), logRetentionDays, dbFlushPeriod)
case "postgresql":
return newDatabaseWriter(postgres.Open(target), logRetentionDays, dbFlushPeriod)
return newDatabaseWriter(ctx, postgres.Open(target), logRetentionDays, dbFlushPeriod)
}
return nil, fmt.Errorf("incorrect database type provided: %s", dbType)
}
func newDatabaseWriter(target gorm.Dialector, logRetentionDays uint64,
func newDatabaseWriter(ctx context.Context, target gorm.Dialector, logRetentionDays uint64,
dbFlushPeriod time.Duration,
) (*DatabaseWriter, error) {
db, err := gorm.Open(target, &gorm.Config{
@ -84,7 +85,7 @@ func newDatabaseWriter(target gorm.Dialector, logRetentionDays uint64,
dbFlushPeriod: dbFlushPeriod,
}
go w.periodicFlush()
go w.periodicFlush(ctx)
return w, nil
}
@ -118,16 +119,20 @@ func databaseMigration(db *gorm.DB) error {
return nil
}
func (d *DatabaseWriter) periodicFlush() {
func (d *DatabaseWriter) periodicFlush(ctx context.Context) {
ticker := time.NewTicker(d.dbFlushPeriod)
defer ticker.Stop()
for {
<-ticker.C
select {
case <-ticker.C:
err := d.doDBWrite()
util.LogOnError("can't write entries to the database: ", err)
case <-ctx.Done():
return
}
}
}
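
A trade-off worth noting: once ctx is cancelled, periodicFlush returns immediately and entries still buffered in memory are not written. If a final write on shutdown were wanted, the ctx.Done branch could flush one last time; a hedged design sketch (periodicFlushWithFinalWrite is hypothetical, not part of this commit):

    func (d *DatabaseWriter) periodicFlushWithFinalWrite(ctx context.Context) {
        ticker := time.NewTicker(d.dbFlushPeriod)
        defer ticker.Stop()

        for {
            select {
            case <-ticker.C:
                util.LogOnError("can't write entries to the database: ", d.doDBWrite())
            case <-ctx.Done():
                // best-effort final flush so buffered entries are not lost
                util.LogOnError("final flush on shutdown failed: ", d.doDBWrite())
                return
            }
        }
    }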

View File

@ -1,6 +1,7 @@
package querylog
import (
"context"
"database/sql"
"fmt"
"time"
@ -20,19 +21,26 @@ import (
var err error
var _ = Describe("DatabaseWriter", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
})
Describe("Database query log to sqlite", func() {
var (
sqliteDB gorm.Dialector
writer *DatabaseWriter
)
BeforeEach(func() {
sqliteDB = sqlite.Open("file::memory:")
})
When("New log entry was created", func() {
BeforeEach(func() {
writer, err = newDatabaseWriter(sqliteDB, 7, time.Millisecond)
writer, err = newDatabaseWriter(ctx, sqliteDB, 7, time.Millisecond)
Expect(err).Should(Succeed())
db, err := writer.db.DB()
@ -83,7 +91,7 @@ var _ = Describe("DatabaseWriter", func() {
When("> 10000 Entries were created", func() {
BeforeEach(func() {
writer, err = newDatabaseWriter(sqliteDB, 7, time.Millisecond)
writer, err = newDatabaseWriter(ctx, sqliteDB, 7, time.Millisecond)
Expect(err).Should(Succeed())
})
@ -114,7 +122,7 @@ var _ = Describe("DatabaseWriter", func() {
When("There are log entries with timestamp exceeding the retention period", func() {
BeforeEach(func() {
writer, err = newDatabaseWriter(sqliteDB, 1, time.Millisecond)
writer, err = newDatabaseWriter(ctx, sqliteDB, 1, time.Millisecond)
Expect(err).Should(Succeed())
})
@ -162,7 +170,7 @@ var _ = Describe("DatabaseWriter", func() {
Describe("Database query log fails", func() {
When("mysql connection parameters wrong", func() {
It("should fail with an error", func() {
_, err := NewDatabaseWriter("mysql", "wrong param", 7, 1)
_, err := NewDatabaseWriter(ctx, "mysql", "wrong param", 7, 1)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(HavePrefix("can't create database connection"))
})
@ -170,7 +178,7 @@ var _ = Describe("DatabaseWriter", func() {
When("postgresql connection parameters wrong", func() {
It("should fail with an error", func() {
_, err := NewDatabaseWriter("postgresql", "wrong param", 7, 1)
_, err := NewDatabaseWriter(ctx, "postgresql", "wrong param", 7, 1)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(HavePrefix("can't create database connection"))
})
@ -178,7 +186,7 @@ var _ = Describe("DatabaseWriter", func() {
When("invalid database type is specified", func() {
It("should fail with an error", func() {
_, err := NewDatabaseWriter("invalidsql", "", 7, 1)
_, err := NewDatabaseWriter(ctx, "invalidsql", "", 7, 1)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(HavePrefix("incorrect database type provided"))
})
@ -225,7 +233,7 @@ var _ = Describe("DatabaseWriter", func() {
mock.ExpectExec(`ALTER TABLE log_entries ADD column if not exists id serial primary key`).WillReturnResult(sqlmock.NewResult(0, 0))
})
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
Expect(err).Should(Succeed())
})
})
@ -257,7 +265,7 @@ var _ = Describe("DatabaseWriter", func() {
mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnResult(sqlmock.NewResult(0, 0))
})
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
Expect(err).Should(Succeed())
})
})
@ -273,7 +281,7 @@ var _ = Describe("DatabaseWriter", func() {
mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnError(fmt.Errorf("error 1060: duplicate column name"))
})
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
Expect(err).Should(Succeed())
})
@ -286,7 +294,7 @@ var _ = Describe("DatabaseWriter", func() {
mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnError(fmt.Errorf("error XXX: some index error"))
})
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("can't perform auto migration: error XXX: some index error"))
})
@ -298,7 +306,7 @@ var _ = Describe("DatabaseWriter", func() {
mock.ExpectExec("CREATE TABLE `log_entries`").WillReturnError(fmt.Errorf("error XXX: some db error"))
})
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("can't perform auto migration: error XXX: some db error"))
})

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"net"
"sort"
@ -110,8 +111,10 @@ func clientGroupsBlock(cfg config.BlockingConfig) map[string][]string {
}
// NewBlockingResolver returns a new configured instance of the resolver
func NewBlockingResolver(
cfg config.BlockingConfig, redis *redis.Client, bootstrap *Bootstrap,
func NewBlockingResolver(ctx context.Context,
cfg config.BlockingConfig,
redis *redis.Client,
bootstrap *Bootstrap,
) (r *BlockingResolver, err error) {
blockHandler, err := createBlockHandler(cfg)
if err != nil {
@ -120,8 +123,10 @@ func NewBlockingResolver(
downloader := lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport())
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.Loading, cfg.BlackLists, downloader)
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.Loading, cfg.WhiteLists, downloader)
blacklistMatcher, blErr := lists.NewListCache(ctx, lists.ListCacheTypeBlacklist,
cfg.Loading, cfg.BlackLists, downloader)
whitelistMatcher, wlErr := lists.NewListCache(ctx, lists.ListCacheTypeWhitelist,
cfg.Loading, cfg.WhiteLists, downloader)
whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg)
err = multierror.Append(err, blErr, wlErr).ErrorOrNil()
@ -145,14 +150,14 @@ func NewBlockingResolver(
redisClient: redis,
}
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
CleanupInterval: defaultBlockingCleanUpInterval,
}, func(key string) (val *[]net.IP, ttl time.Duration) {
return res.queryForFQIdentifierIPs(key)
})
if res.redisClient != nil {
setupRedisEnabledSubscriber(res)
setupRedisEnabledSubscriber(ctx, res)
}
err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) {
@ -166,20 +171,26 @@ func NewBlockingResolver(
return res, nil
}
func setupRedisEnabledSubscriber(c *BlockingResolver) {
func setupRedisEnabledSubscriber(ctx context.Context, c *BlockingResolver) {
go func() {
for em := range c.redisClient.EnabledChannel {
for {
select {
case em := <-c.redisClient.EnabledChannel:
if em != nil {
c.log().Debug("Received state from redis: ", em)
if em.State {
c.internalEnableBlocking()
} else {
err := c.internalDisableBlocking(em.Duration, em.Groups)
if err != nil {
c.log().Warn("Blocking couldn't be disabled:", err)
}
}
}
case <-ctx.Done():
return
}
}
}()
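
One subtlety of this rewrite: the previous 'for em := range ...' loop terminated when the producer closed the channel, while a plain receive inside a select keeps firing with zero values after close (masked here by the nil check). A hedged sketch of a subscriber loop that handles both close and cancellation explicitly (consume is an illustrative helper, not in the codebase):

    // consume drains ch until it is closed or ctx is cancelled.
    // Generic sketch of the subscriber loops introduced in this commit.
    func consume[T any](ctx context.Context, ch <-chan T, handle func(T)) {
        for {
            select {
            case v, ok := <-ch:
                if !ok {
                    return // channel closed by producer
                }
                handle(v)
            case <-ctx.Done():
                return // shutdown requested
            }
        }
    }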

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"time"
"github.com/0xERR0R/blocky/config"
@ -50,6 +51,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
sutConfig config.BlockingConfig
m *mockResolver
mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -59,6 +62,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
@ -72,7 +78,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
sut, err = NewBlockingResolver(ctx, sutConfig, nil, systemResolverBootstrap)
Expect(err).Should(Succeed())
sut.Next(m)
})
@ -113,7 +120,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
Expect(err).Should(Succeed())
// recreate to trigger a reload
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
sut, err = NewBlockingResolver(ctx, sutConfig, nil, systemResolverBootstrap)
Expect(err).Should(Succeed())
Eventually(groupCnt, "1s").Should(HaveLen(2))
@ -1111,7 +1118,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
Describe("Create resolver with wrong parameter", func() {
When("Wrong blockType is used", func() {
It("should return error", func() {
_, err := NewBlockingResolver(config.BlockingConfig{
_, err := NewBlockingResolver(ctx, config.BlockingConfig{
BlockType: "wrong",
}, nil, systemResolverBootstrap)
@ -1121,7 +1128,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("strategy is failOnError", func() {
It("should fail if lists can't be downloaded", func() {
_, err := NewBlockingResolver(config.BlockingConfig{
_, err := NewBlockingResolver(ctx, config.BlockingConfig{
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")},
WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")},
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError},
@ -1155,7 +1162,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
BlockTTL: config.Duration(time.Minute),
}
sut, err = NewBlockingResolver(sutConfig, redisClient, systemResolverBootstrap)
sut, err = NewBlockingResolver(ctx, sutConfig, redisClient, systemResolverBootstrap)
Expect(err).Should(Succeed())
})
JustAfterEach(func() {

View File

@ -42,7 +42,7 @@ type Bootstrap struct {
// NewBootstrap creates and returns a new Bootstrap.
// Internally, it uses a CachingResolver and an UpstreamResolver.
func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err error) {
logger := log.PrefixedLog("bootstrap")
timeout := defaultTimeout
@ -97,7 +97,8 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
b.resolver = Chain(
NewFilteringResolver(cfg.Filtering),
newCachingResolver(cachingCfg, nil, false), // false: no metrics, to not overwrite the main blocking resolver ones
// false: no metrics, to not overwrite the main blocking resolver ones
newCachingResolver(ctx, cachingCfg, nil, false),
parallelResolver,
)

View File

@ -25,6 +25,8 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
var (
sut *Bootstrap
sutConfig *config.Config
ctx context.Context
cancelFn context.CancelFunc
err error
)
@ -41,10 +43,13 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
},
},
}
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
})
JustBeforeEach(func() {
sut, err = NewBootstrap(sutConfig)
sut, err = NewBootstrap(ctx, sutConfig)
Expect(err).Should(Succeed())
})
@ -98,7 +103,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
},
}
_, err := NewBootstrap(&cfg)
_, err := NewBootstrap(ctx, &cfg)
Expect(err).ShouldNot(Succeed())
})
})
@ -140,7 +145,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
},
}
_, err := NewBootstrap(&cfg)
_, err := NewBootstrap(ctx, &cfg)
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("must use IP instead of hostname"))
})
@ -184,7 +189,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
},
}
_, err := NewBootstrap(&cfg)
_, err := NewBootstrap(ctx, &cfg)
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("no IPs configured"))
})
@ -203,7 +208,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
},
}
_, err := NewBootstrap(&cfg)
_, err := NewBootstrap(ctx, &cfg)
Expect(err).Should(Succeed())
})
})

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"math"
"sync/atomic"
@ -35,11 +36,18 @@ type CachingResolver struct {
}
// NewCachingResolver creates a new resolver instance
func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver {
return newCachingResolver(cfg, redis, true)
func NewCachingResolver(ctx context.Context,
cfg config.CachingConfig,
redis *redis.Client,
) *CachingResolver {
return newCachingResolver(ctx, cfg, redis, true)
}
func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetricEvents bool) *CachingResolver {
func newCachingResolver(ctx context.Context,
cfg config.CachingConfig,
redis *redis.Client,
emitMetricEvents bool,
) *CachingResolver {
c := &CachingResolver{
configurable: withConfig(&cfg),
typed: withType("caching"),
@ -48,17 +56,17 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
emitMetricEvents: emitMetricEvents,
}
configureCaches(c, &cfg)
configureCaches(ctx, c, &cfg)
if c.redisClient != nil {
setupRedisCacheSubscriber(c)
setupRedisCacheSubscriber(ctx, c)
c.redisClient.GetRedisCache()
}
return c
}
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.CachingConfig) {
options := expirationcache.Options{
CleanupInterval: defaultCachingCleanUpInterval,
MaxSize: uint(cfg.MaxItemsCount),
@ -91,9 +99,9 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
},
}
c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions)
c.resultCache = expirationcache.NewPrefetchingCache(ctx, prefetchingOptions)
} else {
c.resultCache = expirationcache.NewCache[dns.Msg](options)
c.resultCache = expirationcache.NewCache[dns.Msg](ctx, options)
}
}
@ -117,13 +125,19 @@ func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Dura
return nil, 0
}
func setupRedisCacheSubscriber(c *CachingResolver) {
func setupRedisCacheSubscriber(ctx context.Context, c *CachingResolver) {
go func() {
for rc := range c.redisClient.CacheChannel {
for {
select {
case rc := <-c.redisClient.CacheChannel:
if rc != nil {
c.log().Debug("Received key from redis: ", rc.Key)
ttl := c.adjustTTLs(rc.Response.Res.Answer)
c.putInCache(rc.Key, rc.Response, ttl, false)
}
case <-ctx.Done():
return
}
}
}()

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"time"
@ -26,6 +27,8 @@ var _ = Describe("CachingResolver", func() {
sutConfig config.CachingConfig
m *mockResolver
mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -43,7 +46,10 @@ var _ = Describe("CachingResolver", func() {
})
JustBeforeEach(func() {
sut = NewCachingResolver(sutConfig, nil)
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewCachingResolver(ctx, sutConfig, nil)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut.Next(m)
@ -79,7 +85,7 @@ var _ = Describe("CachingResolver", func() {
It("should prefetch domain if query count > threshold", func() {
// prepare resolver, set smaller caching times for testing
prefetchThreshold := 5
configureCaches(sut, &sutConfig)
configureCaches(ctx, sut, &sutConfig)
domainPrefetched := make(chan bool, 1)
prefetchHitDomain := make(chan bool, 1)
@ -725,7 +731,7 @@ var _ = Describe("CachingResolver", func() {
}
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1000, A, "1.1.1.1")
sut = NewCachingResolver(sutConfig, redisClient)
sut = NewCachingResolver(ctx, sutConfig, redisClient)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut.Next(m)

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"net"
"strings"
"time"
@ -26,7 +27,7 @@ type ClientNamesResolver struct {
}
// NewClientNamesResolver creates new resolver instance
func NewClientNamesResolver(
func NewClientNamesResolver(ctx context.Context,
cfg config.ClientLookupConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (cr *ClientNamesResolver, err error) {
var r Resolver
@ -41,7 +42,7 @@ func NewClientNamesResolver(
configurable: withConfig(&cfg),
typed: withType("client_names"),
cache: expirationcache.NewCache[[]string](expirationcache.Options{
cache: expirationcache.NewCache[[]string](ctx, expirationcache.Options{
CleanupInterval: time.Hour,
}),
externalResolver: r,

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"errors"
"net"
@ -21,6 +22,8 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
sut *ClientNamesResolver
sutConfig config.ClientLookupConfig
m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -31,8 +34,10 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
JustBeforeEach(func() {
var err error
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut, err = NewClientNamesResolver(sutConfig, nil, false)
sut, err = NewClientNamesResolver(ctx, sutConfig, nil, false)
Expect(err).Should(Succeed())
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
@ -361,9 +366,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
Describe("Construction", func() {
When("upstream is invalid", func() {
It("errors during construction", func() {
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewClientNamesResolver(config.ClientLookupConfig{
r, err := NewClientNamesResolver(ctx, config.ClientLookupConfig{
Upstream: config.Upstream{Host: "example.com"},
}, b, true)

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
@ -184,7 +186,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
When("upstream is invalid", func() {
It("errors during construction", func() {
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
Mapping: config.ConditionalUpstreamMapping{

View File

@ -34,7 +34,10 @@ type HostsFileResolver struct {
downloader lists.FileDownloader
}
func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*HostsFileResolver, error) {
func NewHostsFileResolver(ctx context.Context,
cfg config.HostsFileConfig,
bootstrap *Bootstrap,
) (*HostsFileResolver, error) {
r := HostsFileResolver{
configurable: withConfig(&cfg),
typed: withType("hosts_file"),
@ -42,7 +45,7 @@ func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*Ho
downloader: lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()),
}
err := cfg.Loading.StartPeriodicRefresh(r.loadSources, func(err error) {
err := cfg.Loading.StartPeriodicRefresh(ctx, r.loadSources, func(err error) {
r.log().WithError(err).Errorf("could not load hosts files")
})
if err != nil {

View File

@ -53,7 +53,9 @@ var _ = Describe("HostsFileResolver", func() {
JustBeforeEach(func() {
var err error
sut, err = NewHostsFileResolver(sutConfig, systemResolverBootstrap)
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap)
Expect(err).Should(Succeed())
m = &mockResolver{}

View File

@ -118,10 +118,10 @@ func autoAnswer(qType dns.Type, qName string) (*dns.Msg, error) {
}
// newTestBootstrap creates a test Bootstrap
func newTestBootstrap(response *dns.Msg) *Bootstrap {
func newTestBootstrap(ctx context.Context, response *dns.Msg) *Bootstrap {
bootstrapUpstream := &mockResolver{}
b, err := NewBootstrap(&config.Config{})
b, err := NewBootstrap(ctx, &config.Config{})
util.FatalOnError("can't create bootstrap", err)
b.resolver = bootstrapUpstream

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"strings"
"time"
@ -24,6 +25,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
sut *ParallelBestResolver
sutMapping config.UpstreamGroups
sutVerify bool
ctx context.Context
cancelFn context.CancelFunc
err error
@ -37,6 +40,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
config.Upstream{
@ -111,7 +117,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
When("no upstream resolvers can be reached", func() {
BeforeEach(func() {
bootstrap = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
bootstrap = newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
sutMapping = config.UpstreamGroups{
upstreamDefaultCfgName: {
@ -328,7 +334,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
When("upstream is invalid", func() {
It("errors during construction", func() {
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewParallelBestResolver(config.UpstreamsConfig{
Groups: config.UpstreamGroups{"test": {{Host: "example.com"}}},

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"time"
"github.com/0xERR0R/blocky/config"
@ -29,7 +30,7 @@ type QueryLoggingResolver struct {
}
// NewQueryLoggingResolver returns a new resolver instance
func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLogConfig) *QueryLoggingResolver {
logger := log.PrefixedLog(queryLoggingResolverType)
var writer querylog.Writer
@ -43,10 +44,10 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
case config.QueryLogTypeCsvClient:
writer, err = querylog.NewCSVWriter(cfg.Target, true, cfg.LogRetentionDays)
case config.QueryLogTypeMysql:
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays,
writer, err = querylog.NewDatabaseWriter(ctx, "mysql", cfg.Target, cfg.LogRetentionDays,
cfg.FlushInterval.ToDuration())
case config.QueryLogTypePostgresql:
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays,
writer, err = querylog.NewDatabaseWriter(ctx, "postgresql", cfg.Target, cfg.LogRetentionDays,
cfg.FlushInterval.ToDuration())
case config.QueryLogTypeConsole:
writer = querylog.NewLoggerWriter()
@ -81,23 +82,27 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
writer: writer,
}
go resolver.writeLog()
go resolver.writeLog(ctx)
if cfg.LogRetentionDays > 0 {
go resolver.periodicCleanUp()
go resolver.periodicCleanUp(ctx)
}
return &resolver
}
// periodically triggers cleanup of old log files
func (r *QueryLoggingResolver) periodicCleanUp() {
func (r *QueryLoggingResolver) periodicCleanUp(ctx context.Context) {
ticker := time.NewTicker(cleanUpRunPeriod)
defer ticker.Stop()
for {
<-ticker.C
select {
case <-ticker.C:
r.doCleanUp()
case <-ctx.Done():
return
}
}
}
@ -164,18 +169,23 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response *
}
// write entry: if log directory is configured, write to log file
func (r *QueryLoggingResolver) writeLog() {
for logEntry := range r.logChan {
start := time.Now()
func (r *QueryLoggingResolver) writeLog(ctx context.Context) {
for {
select {
case logEntry := <-r.logChan:
start := time.Now()
r.writer.Write(logEntry)
halfCap := cap(r.logChan) / 2 //nolint:gomnd
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
if len(r.logChan) > halfCap {
r.log().WithField("channel_len",
len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
}
case <-ctx.Done():
return
}
}
}
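
As with the database writer above, cancelling ctx makes writeLog return even if entries are still queued in logChan. A hedged sketch of a drain step that could run before exiting (drainAndExit is hypothetical, not part of this commit):

    // drainAndExit writes any entries already queued in logChan, then returns.
    func (r *QueryLoggingResolver) drainAndExit() {
        for {
            select {
            case logEntry := <-r.logChan:
                r.writer.Write(logEntry)
            default:
                return // channel empty, safe to exit
            }
        }
    }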

View File

@ -2,6 +2,7 @@ package resolver
import (
"bufio"
"context"
"encoding/csv"
"errors"
"fmt"
@ -63,8 +64,10 @@ var _ = Describe("QueryLoggingResolver", func() {
sutConfig.SetDefaults() // not called when using a struct literal
}
sut = NewQueryLoggingResolver(sutConfig)
DeferCleanup(func() { close(sut.logChan) })
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewQueryLoggingResolver(ctx, sutConfig)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
sut.Next(m)
@ -151,7 +154,7 @@ var _ = Describe("QueryLoggingResolver", func() {
g.Expect(csvLines[0][7]).Should(Equal("NOERROR"))
g.Expect(csvLines[0][8]).Should(Equal("RESOLVED"))
g.Expect(csvLines[0][9]).Should(Equal("A"))
}, "1s").Should(Succeed())
}).Should(Succeed())
})
By("check log for client2", func() {
@ -169,7 +172,7 @@ var _ = Describe("QueryLoggingResolver", func() {
g.Expect(csvLines[0][7]).Should(Equal("NOERROR"))
g.Expect(csvLines[0][8]).Should(Equal("RESOLVED"))
g.Expect(csvLines[0][9]).Should(Equal("A"))
}, "1s").Should(Succeed())
}).Should(Succeed())
})
})
})

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"strings"
"github.com/0xERR0R/blocky/config"
@ -109,16 +110,24 @@ var _ = Describe("Resolver", func() {
})
Describe("Name", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
})
When("'Name' is called", func() {
It("should return resolver name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
br, _ := NewBlockingResolver(ctx, config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
name := Name(br)
Expect(name).Should(Equal("blocking"))
})
})
When("'Name' is called on a NamedResolver", func() {
It("should return its custom name", func() {
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
br, _ := NewBlockingResolver(ctx, config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
cfg := config.RewriterConfig{Rewrite: map[string]string{"not": "empty"}}
r := NewRewriterResolver(cfg, br)

View File

@ -2,6 +2,7 @@ package server
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@ -113,7 +114,7 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(cfg *config.Config) (server *Server, err error) {
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
log.ConfigureLogger(&cfg.Log)
var cert tls.Certificate
@ -145,7 +146,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
metrics.RegisterEventListeners()
bootstrap, err := resolver.NewBootstrap(cfg)
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
if err != nil {
return nil, err
}
@ -155,7 +156,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
return nil, redisErr
}
queryResolver, queryError := createQueryResolver(cfg, bootstrap, redisClient)
queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient)
if queryError != nil {
return nil, queryError
}
@ -385,6 +386,7 @@ func createSelfSignedCert() (tls.Certificate, error) {
}
func createQueryResolver(
ctx context.Context,
cfg *config.Config,
bootstrap *resolver.Bootstrap,
redisClient *redis.Client,
@ -396,10 +398,10 @@ func createQueryResolver(
upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches)
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap)
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
err = multierror.Append(
multierror.Prefix(utErr, "upstream tree resolver: "),
@ -417,12 +419,12 @@ func createQueryResolver(
resolver.NewFqdnOnlyResolver(cfg.FqdnOnly),
clientNames,
resolver.NewEdeResolver(cfg.Ede),
resolver.NewQueryLoggingResolver(cfg.QueryLog),
resolver.NewQueryLoggingResolver(ctx, cfg.QueryLog),
resolver.NewMetricsResolver(cfg.Prometheus),
resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
hostsFile,
blocking,
resolver.NewCachingResolver(cfg.Caching, redisClient),
resolver.NewCachingResolver(ctx, cfg.Caching, redisClient),
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),
resolver.NewSpecialUseDomainNamesResolver(cfg.SUDN),
upstreamTree,
@ -513,7 +515,7 @@ const (
)
// Start starts the server
func (s *Server) Start(errCh chan<- error) {
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
logger().Info("Starting server")
for _, srv := range s.dnsServers {
@ -572,7 +574,7 @@ func (s *Server) Start(errCh chan<- error) {
}()
}
registerPrintConfigurationTrigger(s)
registerPrintConfigurationTrigger(ctx, s)
}
// Stop stops the server

View File

@ -4,19 +4,25 @@
package server
import (
"context"
"os"
"os/signal"
"syscall"
)
func registerPrintConfigurationTrigger(s *Server) {
func registerPrintConfigurationTrigger(ctx context.Context, s *Server) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGUSR1)
go func() {
for {
<-signals
select {
case <-signals:
s.printConfiguration()
case <-ctx.Done():
return
}
}
}()
}
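
On cancellation the goroutine exits, but the SIGUSR1 registration made by signal.Notify stays in place. A hedged refinement, not in this commit, would deregister it on exit:

    go func() {
        defer signal.Stop(signals) // release the SIGUSR1 registration when the goroutine exits

        for {
            select {
            case <-signals:
                s.printConfiguration()
            case <-ctx.Done():
                return
            }
        }
    }()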

View File

@ -2,6 +2,7 @@ package server
import (
"bytes"
"context"
"encoding/base64"
"io"
"net"
@ -32,6 +33,8 @@ var (
var _ = BeforeSuite(func() {
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
googleMockUpstream := resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
if request.Question[0].Name == "error." {
return nil
@ -102,7 +105,7 @@ var _ = BeforeSuite(func() {
Expect(youtubeFile.Error).Should(Succeed())
// create server
sut, err = NewServer(&config.Config{
sut, err = NewServer(ctx, &config.Config{
CustomDNS: config.CustomDNSConfig{
CustomTTL: config.Duration(3600 * time.Second),
Mapping: config.CustomDNSMapping{
@ -169,7 +172,7 @@ var _ = BeforeSuite(func() {
// start server
go func() {
sut.Start(errChan)
sut.Start(ctx, errChan)
}()
DeferCleanup(sut.Stop)
@ -177,6 +180,14 @@ var _ = BeforeSuite(func() {
})
var _ = Describe("Running DNS server", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
})
Describe("performing DNS request with running server", func() {
BeforeEach(func() {
mockClientName.Store("")
@ -555,14 +566,14 @@ var _ = Describe("Running DNS server", func() {
})
When("Server is created", func() {
It("is created without redis connection", func() {
_, err := NewServer(&cfg)
_, err := NewServer(ctx, &cfg)
Expect(err).Should(Succeed())
})
It("can't be created if redis server is unavailable", func() {
cfg.Redis.Required = true
_, err := NewServer(&cfg)
_, err := NewServer(ctx, &cfg)
Expect(err).ShouldNot(Succeed())
})
@ -573,7 +584,7 @@ var _ = Describe("Running DNS server", func() {
When("Server start is called", func() {
It("start was called 2 times, start should fail", func() {
// create server
server, err := NewServer(&config.Config{
server, err := NewServer(ctx, &config.Config{
Upstreams: config.UpstreamsConfig{
Groups: map[string][]config.Upstream{
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
@ -598,14 +609,14 @@ var _ = Describe("Running DNS server", func() {
errChan := make(chan error, 10)
// start server
go server.Start(errChan)
go server.Start(ctx, errChan)
DeferCleanup(server.Stop)
Consistently(errChan, "1s").ShouldNot(Receive())
// start again -> should fail
server.Start(errChan)
server.Start(ctx, errChan)
Eventually(errChan).Should(Receive())
})
@ -615,7 +626,7 @@ var _ = Describe("Running DNS server", func() {
When("Stop is called", func() {
It("stop was called 2 times, start should fail", func() {
// create server
server, err := NewServer(&config.Config{
server, err := NewServer(ctx, &config.Config{
Upstreams: config.UpstreamsConfig{
Groups: map[string][]config.Upstream{
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
@ -641,7 +652,7 @@ var _ = Describe("Running DNS server", func() {
// start server
go func() {
server.Start(errChan)
server.Start(ctx, errChan)
}()
time.Sleep(100 * time.Millisecond)
@ -681,7 +692,7 @@ var _ = Describe("Running DNS server", func() {
Describe("create query resolver", func() {
When("some upstream returns error", func() {
It("create query resolver should return error", func() {
r, err := createQueryResolver(&config.Config{
r, err := createQueryResolver(ctx, &config.Config{
StartVerifyUpstream: true,
Upstreams: config.UpstreamsConfig{
Groups: config.UpstreamGroups{
@ -736,8 +747,7 @@ var _ = Describe("Running DNS server", func() {
cfg.Ports = config.PortsConfig{
HTTPS: []string{":14443"},
}
sut, err := NewServer(&cfg)
sut, err := NewServer(ctx, &cfg)
Expect(err).Should(Succeed())
Expect(sut.cert.Certificate).ShouldNot(BeNil())
})