mirror of https://github.com/0xERR0R/blocky.git
refactor: pass context for goroutine shutdown (#1187)
This commit is contained in:
parent
d9e91da686
commit
33ea933015
|
@ -1,6 +1,7 @@
|
|||
package expirationcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
|
@ -47,11 +48,11 @@ type OnCacheMissCallback func(key string)
|
|||
// OnAfterPutCallback will be called after put, receives new element count as parameter
|
||||
type OnAfterPutCallback func(newSize int)
|
||||
|
||||
func NewCache[T any](options Options) *ExpiringLRUCache[T] {
|
||||
return NewCacheWithOnExpired[T](options, nil)
|
||||
func NewCache[T any](ctx context.Context, options Options) *ExpiringLRUCache[T] {
|
||||
return NewCacheWithOnExpired[T](ctx, options, nil)
|
||||
}
|
||||
|
||||
func NewCacheWithOnExpired[T any](options Options,
|
||||
func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
|
||||
onExpirationFn OnExpirationCallback[T],
|
||||
) *ExpiringLRUCache[T] {
|
||||
l, _ := lru.New(defaultSize)
|
||||
|
@ -90,18 +91,22 @@ func NewCacheWithOnExpired[T any](options Options,
|
|||
c.preExpirationFn = onExpirationFn
|
||||
}
|
||||
|
||||
go periodicCleanup(c)
|
||||
go periodicCleanup(ctx, c)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func periodicCleanup[T any](c *ExpiringLRUCache[T]) {
|
||||
func periodicCleanup[T any](ctx context.Context, c *ExpiringLRUCache[T]) {
|
||||
ticker := time.NewTicker(c.cleanUpInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
c.cleanUp()
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.cleanUp()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package expirationcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -8,14 +9,22 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Expiration cache", func() {
|
||||
var (
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
})
|
||||
Describe("Basic operations", func() {
|
||||
When("string cache was created", func() {
|
||||
It("Initial cache should be empty", func() {
|
||||
cache := NewCache[string](Options{})
|
||||
cache := NewCache[string](ctx, Options{})
|
||||
Expect(cache.TotalCount()).Should(Equal(0))
|
||||
})
|
||||
It("Initial cache should not contain any elements", func() {
|
||||
cache := NewCache[string](Options{})
|
||||
cache := NewCache[string](ctx, Options{})
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(BeNil())
|
||||
Expect(expiration).Should(Equal(time.Duration(0)))
|
||||
|
@ -23,7 +32,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
When("Put new value with positive TTL", func() {
|
||||
It("Should return the value before element expires", func() {
|
||||
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
|
||||
cache := NewCache[string](ctx, Options{CleanupInterval: 100 * time.Millisecond})
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
val, expiration := cache.Get("key1")
|
||||
|
@ -33,7 +42,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
Expect(cache.TotalCount()).Should(Equal(1))
|
||||
})
|
||||
It("Should return nil after expiration", func() {
|
||||
cache := NewCache[string](Options{CleanupInterval: 100 * time.Millisecond})
|
||||
cache := NewCache[string](ctx, Options{CleanupInterval: 100 * time.Millisecond})
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
|
||||
|
@ -47,12 +56,12 @@ var _ = Describe("Expiration cache", func() {
|
|||
// wait for cleanup run
|
||||
Eventually(func() int {
|
||||
return cache.lru.Len()
|
||||
}, "100ms").Should(Equal(0))
|
||||
}).Should(Equal(0))
|
||||
})
|
||||
})
|
||||
When("Put new value without expiration", func() {
|
||||
It("Should not cache the value", func() {
|
||||
cache := NewCache[string](Options{CleanupInterval: 50 * time.Millisecond})
|
||||
cache := NewCache[string](ctx, Options{CleanupInterval: 50 * time.Millisecond})
|
||||
v := "x"
|
||||
cache.Put("key1", &v, 0)
|
||||
val, expiration := cache.Get("key1")
|
||||
|
@ -63,7 +72,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
When("Put updated value", func() {
|
||||
It("Should return updated value", func() {
|
||||
cache := NewCache[string](Options{})
|
||||
cache := NewCache[string](ctx, Options{})
|
||||
v1 := "v1"
|
||||
v2 := "v2"
|
||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||
|
@ -79,7 +88,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
When("Purging after usage", func() {
|
||||
It("Should be empty after purge", func() {
|
||||
cache := NewCache[string](Options{})
|
||||
cache := NewCache[string](ctx, Options{})
|
||||
v1 := "y"
|
||||
cache.Put("key1", &v1, time.Second)
|
||||
|
||||
|
@ -97,7 +106,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
onCacheHitChannel := make(chan string, 10)
|
||||
onCacheMissChannel := make(chan string, 10)
|
||||
onAfterPutChannel := make(chan int, 10)
|
||||
cache := NewCache[string](Options{
|
||||
cache := NewCache[string](ctx, Options{
|
||||
OnCacheHitFn: func(key string) {
|
||||
onCacheHitChannel <- key
|
||||
},
|
||||
|
@ -146,7 +155,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
return &v2, time.Second
|
||||
}
|
||||
|
||||
cache := NewCacheWithOnExpired[string](Options{}, fn)
|
||||
cache := NewCacheWithOnExpired[string](ctx, Options{}, fn)
|
||||
v1 := "v1"
|
||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||
|
||||
|
@ -165,7 +174,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
|
||||
return &v2, time.Second
|
||||
}
|
||||
cache := NewCacheWithOnExpired[string](Options{}, fn)
|
||||
cache := NewCacheWithOnExpired[string](ctx, Options{}, fn)
|
||||
v1 := "somval"
|
||||
cache.Put("key1", &v1, time.Millisecond)
|
||||
|
||||
|
@ -186,7 +195,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
fn := func(key string) (val *string, ttl time.Duration) {
|
||||
return nil, 0
|
||||
}
|
||||
cache := NewCacheWithOnExpired[string](Options{CleanupInterval: 100 * time.Microsecond}, fn)
|
||||
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
|
||||
v1 := "z"
|
||||
cache.Put("key1", &v1, 50*time.Millisecond)
|
||||
|
||||
|
@ -199,7 +208,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
Describe("LRU behaviour", func() {
|
||||
When("Defined max size is reached", func() {
|
||||
It("should remove old elements", func() {
|
||||
cache := NewCache[string](Options{MaxSize: 3})
|
||||
cache := NewCache[string](ctx, Options{MaxSize: 3})
|
||||
|
||||
v1 := "val1"
|
||||
v2 := "val2"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package expirationcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
@ -39,9 +40,9 @@ type PrefetchingOptions[T any] struct {
|
|||
|
||||
type PrefetchingCacheOption[T any] func(c *PrefetchingExpiringLRUCache[cacheValue[T]])
|
||||
|
||||
func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
|
||||
func NewPrefetchingCache[T any](ctx context.Context, options PrefetchingOptions[T]) *PrefetchingExpiringLRUCache[T] {
|
||||
pc := &PrefetchingExpiringLRUCache[T]{
|
||||
prefetchingNameCache: NewCache[atomic.Uint32](Options{
|
||||
prefetchingNameCache: NewCache[atomic.Uint32](ctx, Options{
|
||||
CleanupInterval: time.Minute,
|
||||
MaxSize: uint(options.PrefetchMaxItemsCount),
|
||||
OnAfterPutFn: options.OnPrefetchAfterPut,
|
||||
|
@ -53,7 +54,7 @@ func NewPrefetchingCache[T any](options PrefetchingOptions[T]) *PrefetchingExpir
|
|||
onPrefetchCacheHit: options.OnPrefetchCacheHit,
|
||||
}
|
||||
|
||||
pc.cache = NewCacheWithOnExpired[cacheValue[T]](options.Options, pc.onExpired)
|
||||
pc.cache = NewCacheWithOnExpired[cacheValue[T]](ctx, options.Options, pc.onExpired)
|
||||
|
||||
return pc
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package expirationcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -8,21 +9,29 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Prefetching expiration cache", func() {
|
||||
var (
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
})
|
||||
Describe("Basic operations", func() {
|
||||
When("string cache was created", func() {
|
||||
It("Initial cache should be empty", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{})
|
||||
Expect(cache.TotalCount()).Should(Equal(0))
|
||||
})
|
||||
It("Initial cache should not contain any elements", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{})
|
||||
val, expiration := cache.Get("key1")
|
||||
Expect(val).Should(BeNil())
|
||||
Expect(expiration).Should(Equal(time.Duration(0)))
|
||||
})
|
||||
|
||||
It("Should work as cache (basic operations)", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{})
|
||||
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{})
|
||||
v := "v1"
|
||||
cache.Put("key1", &v, 50*time.Millisecond)
|
||||
|
||||
|
@ -39,7 +48,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
})
|
||||
Context("Prefetching", func() {
|
||||
It("Should prefetch element", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
|
@ -71,7 +80,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
})
|
||||
})
|
||||
It("Should not prefetch element", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
|
@ -100,7 +109,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
})
|
||||
})
|
||||
It("With default config (threshold = 0) should always prefetch", func() {
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
|
@ -128,7 +137,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
onPrefetchAfterPutChannel := make(chan int, 10)
|
||||
onPrefetchEntryReloaded := make(chan string, 10)
|
||||
onnPrefetchCacheHit := make(chan string, 10)
|
||||
cache := NewPrefetchingCache[string](PrefetchingOptions[string]{
|
||||
cache := NewPrefetchingCache[string](ctx, PrefetchingOptions[string]{
|
||||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -43,7 +44,10 @@ func startServer(_ *cobra.Command, _ []string) error {
|
|||
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
srv, err := server.NewServer(cfg)
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
defer cancelFn()
|
||||
|
||||
srv, err := server.NewServer(ctx, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't start server: %w", err)
|
||||
}
|
||||
|
@ -51,7 +55,7 @@ func startServer(_ *cobra.Command, _ []string) error {
|
|||
const errChanSize = 10
|
||||
errChan := make(chan error, errChanSize)
|
||||
|
||||
srv.Start(errChan)
|
||||
srv.Start(ctx, errChan)
|
||||
|
||||
var terminationErr error
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ var _ = Describe("Serve command", func() {
|
|||
|
||||
By("terminate with signal", func() {
|
||||
var startError error
|
||||
Eventually(errChan).Should(Receive(&startError))
|
||||
Eventually(errChan, "10s").Should(Receive(&startError))
|
||||
Expect(startError).ShouldNot(BeNil())
|
||||
Expect(startError.Error()).Should(ContainSubstring("address already in use"))
|
||||
})
|
||||
|
|
|
@ -314,7 +314,10 @@ func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) {
|
|||
log.WithIndent(logger, " ", c.Downloads.LogConfig)
|
||||
}
|
||||
|
||||
func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context) error, logErr func(error)) error {
|
||||
func (c *SourceLoadingConfig) StartPeriodicRefresh(ctx context.Context,
|
||||
refresh func(context.Context) error,
|
||||
logErr func(error),
|
||||
) error {
|
||||
refreshAndRecover := func(ctx context.Context) (rerr error) {
|
||||
defer func() {
|
||||
if val := recover(); val != nil {
|
||||
|
@ -331,20 +334,29 @@ func (c *SourceLoadingConfig) StartPeriodicRefresh(refresh func(context.Context)
|
|||
}
|
||||
|
||||
if c.RefreshPeriod > 0 {
|
||||
go c.periodically(refreshAndRecover, logErr)
|
||||
go c.periodically(ctx, refreshAndRecover, logErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SourceLoadingConfig) periodically(refresh func(context.Context) error, logErr func(error)) {
|
||||
func (c *SourceLoadingConfig) periodically(ctx context.Context,
|
||||
refresh func(context.Context) error,
|
||||
logErr func(error),
|
||||
) {
|
||||
ticker := time.NewTicker(c.RefreshPeriod.ToDuration())
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
err := refresh(context.Background())
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := refresh(ctx)
|
||||
if err != nil {
|
||||
logErr(err)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -688,6 +688,14 @@ bootstrapDns:
|
|||
})
|
||||
|
||||
Describe("SourceLoadingConfig", func() {
|
||||
var (
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
})
|
||||
It("handles panics", func() {
|
||||
sut := SourceLoadingConfig{
|
||||
Strategy: StartStrategyTypeFailOnError,
|
||||
|
@ -695,7 +703,7 @@ bootstrapDns:
|
|||
|
||||
panicMsg := "panic value"
|
||||
|
||||
err := sut.StartPeriodicRefresh(func(context.Context) error {
|
||||
err := sut.StartPeriodicRefresh(ctx, func(context.Context) error {
|
||||
panic(panicMsg)
|
||||
}, func(err error) {
|
||||
Expect(err).Should(MatchError(ContainSubstring(panicMsg)))
|
||||
|
@ -715,7 +723,7 @@ bootstrapDns:
|
|||
|
||||
var call atomic.Int32
|
||||
|
||||
err := sut.StartPeriodicRefresh(func(context.Context) error {
|
||||
err := sut.StartPeriodicRefresh(ctx, func(context.Context) error {
|
||||
call := call.Add(1)
|
||||
calls <- call
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ func (b *ListCache) LogConfig(logger *logrus.Entry) {
|
|||
}
|
||||
|
||||
// NewListCache creates new list instance
|
||||
func NewListCache(
|
||||
func NewListCache(ctx context.Context,
|
||||
t ListCacheType, cfg config.SourceLoadingConfig,
|
||||
groupSources map[string][]config.BytesSource, downloader FileDownloader,
|
||||
) (*ListCache, error) {
|
||||
|
@ -72,7 +72,7 @@ func NewListCache(
|
|||
downloader: downloader,
|
||||
}
|
||||
|
||||
err := cfg.StartPeriodicRefresh(c.refresh, func(err error) {
|
||||
err := cfg.StartPeriodicRefresh(ctx, c.refresh, func(err error) {
|
||||
logger().WithError(err).Errorf("could not init %s", t)
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package lists
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -19,7 +20,7 @@ func BenchmarkRefresh(b *testing.B) {
|
|||
RefreshPeriod: config.Duration(-1),
|
||||
}
|
||||
downloader := NewDownloader(config.DownloaderConfig{}, nil)
|
||||
cache, _ := NewListCache(ListCacheTypeBlacklist, cfg, lists, downloader)
|
||||
cache, _ := NewListCache(context.Background(), ListCacheTypeBlacklist, cfg, lists, downloader)
|
||||
|
||||
b.ReportAllocs()
|
||||
|
||||
|
|
|
@ -36,11 +36,16 @@ var _ = Describe("ListCache", func() {
|
|||
lists map[string][]config.BytesSource
|
||||
downloader FileDownloader
|
||||
mockDownloader *MockDownloader
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
listCacheType = ListCacheTypeBlacklist
|
||||
|
||||
sutConfig, err = config.WithDefaults[config.SourceLoadingConfig]()
|
||||
|
@ -84,7 +89,7 @@ var _ = Describe("ListCache", func() {
|
|||
downloader = mockDownloader
|
||||
}
|
||||
|
||||
sut, err = NewListCache(listCacheType, sutConfig, lists, downloader)
|
||||
sut, err = NewListCache(ctx, listCacheType, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -304,7 +309,7 @@ var _ = Describe("ListCache", func() {
|
|||
"gr1": config.NewBytesSources(file1, file2, file3),
|
||||
}
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
sut, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(lines1 + lines2 + lines3))
|
||||
|
@ -359,7 +364,7 @@ var _ = Describe("ListCache", func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
_, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err).Should(MatchError(parsers.ErrTooManyErrors))
|
||||
})
|
||||
|
@ -408,7 +413,7 @@ var _ = Describe("ListCache", func() {
|
|||
"gr2": {config.TextBytesSource("inline", "definition")},
|
||||
}
|
||||
|
||||
sut, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
sut, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
sut.LogConfig(logger)
|
||||
|
@ -430,7 +435,7 @@ var _ = Describe("ListCache", func() {
|
|||
"gr1": config.NewBytesSources("doesnotexist"),
|
||||
}
|
||||
|
||||
_, err := NewListCache(ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
_, err := NewListCache(ctx, ListCacheTypeBlacklist, sutConfig, lists, downloader)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -43,20 +44,20 @@ type DatabaseWriter struct {
|
|||
dbFlushPeriod time.Duration
|
||||
}
|
||||
|
||||
func NewDatabaseWriter(dbType, target string, logRetentionDays uint64,
|
||||
func NewDatabaseWriter(ctx context.Context, dbType, target string, logRetentionDays uint64,
|
||||
dbFlushPeriod time.Duration,
|
||||
) (*DatabaseWriter, error) {
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
return newDatabaseWriter(mysql.Open(target), logRetentionDays, dbFlushPeriod)
|
||||
return newDatabaseWriter(ctx, mysql.Open(target), logRetentionDays, dbFlushPeriod)
|
||||
case "postgresql":
|
||||
return newDatabaseWriter(postgres.Open(target), logRetentionDays, dbFlushPeriod)
|
||||
return newDatabaseWriter(ctx, postgres.Open(target), logRetentionDays, dbFlushPeriod)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("incorrect database type provided: %s", dbType)
|
||||
}
|
||||
|
||||
func newDatabaseWriter(target gorm.Dialector, logRetentionDays uint64,
|
||||
func newDatabaseWriter(ctx context.Context, target gorm.Dialector, logRetentionDays uint64,
|
||||
dbFlushPeriod time.Duration,
|
||||
) (*DatabaseWriter, error) {
|
||||
db, err := gorm.Open(target, &gorm.Config{
|
||||
|
@ -84,7 +85,7 @@ func newDatabaseWriter(target gorm.Dialector, logRetentionDays uint64,
|
|||
dbFlushPeriod: dbFlushPeriod,
|
||||
}
|
||||
|
||||
go w.periodicFlush()
|
||||
go w.periodicFlush(ctx)
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
@ -118,16 +119,20 @@ func databaseMigration(db *gorm.DB) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *DatabaseWriter) periodicFlush() {
|
||||
func (d *DatabaseWriter) periodicFlush(ctx context.Context) {
|
||||
ticker := time.NewTicker(d.dbFlushPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := d.doDBWrite()
|
||||
|
||||
err := d.doDBWrite()
|
||||
util.LogOnError("can't write entries to the database: ", err)
|
||||
|
||||
util.LogOnError("can't write entries to the database: ", err)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
@ -20,19 +21,26 @@ import (
|
|||
var err error
|
||||
|
||||
var _ = Describe("DatabaseWriter", func() {
|
||||
var (
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
})
|
||||
Describe("Database query log to sqlite", func() {
|
||||
var (
|
||||
sqliteDB gorm.Dialector
|
||||
writer *DatabaseWriter
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
sqliteDB = sqlite.Open("file::memory:")
|
||||
})
|
||||
|
||||
When("New log entry was created", func() {
|
||||
BeforeEach(func() {
|
||||
writer, err = newDatabaseWriter(sqliteDB, 7, time.Millisecond)
|
||||
writer, err = newDatabaseWriter(ctx, sqliteDB, 7, time.Millisecond)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
db, err := writer.db.DB()
|
||||
|
@ -83,7 +91,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
|
||||
When("> 10000 Entries were created", func() {
|
||||
BeforeEach(func() {
|
||||
writer, err = newDatabaseWriter(sqliteDB, 7, time.Millisecond)
|
||||
writer, err = newDatabaseWriter(ctx, sqliteDB, 7, time.Millisecond)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -114,7 +122,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
|
||||
When("There are log entries with timestamp exceeding the retention period", func() {
|
||||
BeforeEach(func() {
|
||||
writer, err = newDatabaseWriter(sqliteDB, 1, time.Millisecond)
|
||||
writer, err = newDatabaseWriter(ctx, sqliteDB, 1, time.Millisecond)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -162,7 +170,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
Describe("Database query log fails", func() {
|
||||
When("mysql connection parameters wrong", func() {
|
||||
It("should be log with fatal", func() {
|
||||
_, err := NewDatabaseWriter("mysql", "wrong param", 7, 1)
|
||||
_, err := NewDatabaseWriter(ctx, "mysql", "wrong param", 7, 1)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(HavePrefix("can't create database connection"))
|
||||
})
|
||||
|
@ -170,7 +178,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
|
||||
When("postgresql connection parameters wrong", func() {
|
||||
It("should be log with fatal", func() {
|
||||
_, err := NewDatabaseWriter("postgresql", "wrong param", 7, 1)
|
||||
_, err := NewDatabaseWriter(ctx, "postgresql", "wrong param", 7, 1)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(HavePrefix("can't create database connection"))
|
||||
})
|
||||
|
@ -178,7 +186,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
|
||||
When("invalid database type is specified", func() {
|
||||
It("should be log with fatal", func() {
|
||||
_, err := NewDatabaseWriter("invalidsql", "", 7, 1)
|
||||
_, err := NewDatabaseWriter(ctx, "invalidsql", "", 7, 1)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(HavePrefix("incorrect database type provided"))
|
||||
})
|
||||
|
@ -225,7 +233,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
mock.ExpectExec(`ALTER TABLE log_entries ADD column if not exists id serial primary key`).WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
})
|
||||
|
||||
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
|
||||
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
})
|
||||
|
@ -257,7 +265,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
})
|
||||
|
||||
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
|
||||
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
})
|
||||
|
@ -273,7 +281,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnError(fmt.Errorf("error 1060: duplicate column name"))
|
||||
})
|
||||
|
||||
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
|
||||
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -286,7 +294,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
mock.ExpectExec("ALTER TABLE `log_entries` ADD `id` INT PRIMARY KEY AUTO_INCREMENT").WillReturnError(fmt.Errorf("error XXX: some index error"))
|
||||
})
|
||||
|
||||
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
|
||||
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("can't perform auto migration: error XXX: some index error"))
|
||||
})
|
||||
|
@ -298,7 +306,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
mock.ExpectExec("CREATE TABLE `log_entries`").WillReturnError(fmt.Errorf("error XXX: some db error"))
|
||||
})
|
||||
|
||||
_, err = newDatabaseWriter(dlc, 1, time.Millisecond)
|
||||
_, err = newDatabaseWriter(ctx, dlc, 1, time.Millisecond)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("can't perform auto migration: error XXX: some db error"))
|
||||
})
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
|
@ -110,8 +111,10 @@ func clientGroupsBlock(cfg config.BlockingConfig) map[string][]string {
|
|||
}
|
||||
|
||||
// NewBlockingResolver returns a new configured instance of the resolver
|
||||
func NewBlockingResolver(
|
||||
cfg config.BlockingConfig, redis *redis.Client, bootstrap *Bootstrap,
|
||||
func NewBlockingResolver(ctx context.Context,
|
||||
cfg config.BlockingConfig,
|
||||
redis *redis.Client,
|
||||
bootstrap *Bootstrap,
|
||||
) (r *BlockingResolver, err error) {
|
||||
blockHandler, err := createBlockHandler(cfg)
|
||||
if err != nil {
|
||||
|
@ -120,8 +123,10 @@ func NewBlockingResolver(
|
|||
|
||||
downloader := lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport())
|
||||
|
||||
blacklistMatcher, blErr := lists.NewListCache(lists.ListCacheTypeBlacklist, cfg.Loading, cfg.BlackLists, downloader)
|
||||
whitelistMatcher, wlErr := lists.NewListCache(lists.ListCacheTypeWhitelist, cfg.Loading, cfg.WhiteLists, downloader)
|
||||
blacklistMatcher, blErr := lists.NewListCache(ctx, lists.ListCacheTypeBlacklist,
|
||||
cfg.Loading, cfg.BlackLists, downloader)
|
||||
whitelistMatcher, wlErr := lists.NewListCache(ctx, lists.ListCacheTypeWhitelist,
|
||||
cfg.Loading, cfg.WhiteLists, downloader)
|
||||
whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg)
|
||||
|
||||
err = multierror.Append(err, blErr, wlErr).ErrorOrNil()
|
||||
|
@ -145,14 +150,14 @@ func NewBlockingResolver(
|
|||
redisClient: redis,
|
||||
}
|
||||
|
||||
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](expirationcache.Options{
|
||||
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
|
||||
CleanupInterval: defaultBlockingCleanUpInterval,
|
||||
}, func(key string) (val *[]net.IP, ttl time.Duration) {
|
||||
return res.queryForFQIdentifierIPs(key)
|
||||
})
|
||||
|
||||
if res.redisClient != nil {
|
||||
setupRedisEnabledSubscriber(res)
|
||||
setupRedisEnabledSubscriber(ctx, res)
|
||||
}
|
||||
|
||||
err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) {
|
||||
|
@ -166,20 +171,26 @@ func NewBlockingResolver(
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func setupRedisEnabledSubscriber(c *BlockingResolver) {
|
||||
func setupRedisEnabledSubscriber(ctx context.Context, c *BlockingResolver) {
|
||||
go func() {
|
||||
for em := range c.redisClient.EnabledChannel {
|
||||
if em != nil {
|
||||
c.log().Debug("Received state from redis: ", em)
|
||||
for {
|
||||
select {
|
||||
case em := <-c.redisClient.EnabledChannel:
|
||||
if em != nil {
|
||||
c.log().Debug("Received state from redis: ", em)
|
||||
|
||||
if em.State {
|
||||
c.internalEnableBlocking()
|
||||
} else {
|
||||
err := c.internalDisableBlocking(em.Duration, em.Groups)
|
||||
if err != nil {
|
||||
c.log().Warn("Blocking couldn't be disabled:", err)
|
||||
if em.State {
|
||||
c.internalEnableBlocking()
|
||||
} else {
|
||||
err := c.internalDisableBlocking(em.Duration, em.Groups)
|
||||
if err != nil {
|
||||
c.log().Warn("Blocking couldn't be disabled:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -50,6 +51,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
sutConfig config.BlockingConfig
|
||||
m *mockResolver
|
||||
mockAnswer *dns.Msg
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -59,6 +62,9 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
|
@ -72,7 +78,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
|
||||
|
||||
sut, err = NewBlockingResolver(ctx, sutConfig, nil, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
sut.Next(m)
|
||||
})
|
||||
|
@ -113,7 +120,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
Expect(err).Should(Succeed())
|
||||
|
||||
// recreate to trigger a reload
|
||||
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
|
||||
sut, err = NewBlockingResolver(ctx, sutConfig, nil, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Eventually(groupCnt, "1s").Should(HaveLen(2))
|
||||
|
@ -1111,7 +1118,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
Describe("Create resolver with wrong parameter", func() {
|
||||
When("Wrong blockType is used", func() {
|
||||
It("should return error", func() {
|
||||
_, err := NewBlockingResolver(config.BlockingConfig{
|
||||
_, err := NewBlockingResolver(ctx, config.BlockingConfig{
|
||||
BlockType: "wrong",
|
||||
}, nil, systemResolverBootstrap)
|
||||
|
||||
|
@ -1121,7 +1128,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("strategy is failOnError", func() {
|
||||
It("should fail if lists can't be downloaded", func() {
|
||||
_, err := NewBlockingResolver(config.BlockingConfig{
|
||||
_, err := NewBlockingResolver(ctx, config.BlockingConfig{
|
||||
BlackLists: map[string][]config.BytesSource{"gr1": config.NewBytesSources("wrongPath")},
|
||||
WhiteLists: map[string][]config.BytesSource{"whitelist": config.NewBytesSources("wrongPath")},
|
||||
Loading: config.SourceLoadingConfig{Strategy: config.StartStrategyTypeFailOnError},
|
||||
|
@ -1155,7 +1162,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
BlockTTL: config.Duration(time.Minute),
|
||||
}
|
||||
|
||||
sut, err = NewBlockingResolver(sutConfig, redisClient, systemResolverBootstrap)
|
||||
sut, err = NewBlockingResolver(ctx, sutConfig, redisClient, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
JustAfterEach(func() {
|
||||
|
|
|
@ -42,7 +42,7 @@ type Bootstrap struct {
|
|||
|
||||
// NewBootstrap creates and returns a new Bootstrap.
|
||||
// Internally, it uses a CachingResolver and an UpstreamResolver.
|
||||
func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
||||
func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err error) {
|
||||
logger := log.PrefixedLog("bootstrap")
|
||||
|
||||
timeout := defaultTimeout
|
||||
|
@ -97,7 +97,8 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
|
||||
b.resolver = Chain(
|
||||
NewFilteringResolver(cfg.Filtering),
|
||||
newCachingResolver(cachingCfg, nil, false), // false: no metrics, to not overwrite the main blocking resolver ones
|
||||
// false: no metrics, to not overwrite the main blocking resolver ones
|
||||
newCachingResolver(ctx, cachingCfg, nil, false),
|
||||
parallelResolver,
|
||||
)
|
||||
|
||||
|
|
|
@ -25,6 +25,8 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
var (
|
||||
sut *Bootstrap
|
||||
sutConfig *config.Config
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
|
||||
err error
|
||||
)
|
||||
|
@ -41,10 +43,13 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sut, err = NewBootstrap(sutConfig)
|
||||
sut, err = NewBootstrap(ctx, sutConfig)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
||||
|
@ -98,7 +103,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := NewBootstrap(&cfg)
|
||||
_, err := NewBootstrap(ctx, &cfg)
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
})
|
||||
})
|
||||
|
@ -140,7 +145,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := NewBootstrap(&cfg)
|
||||
_, err := NewBootstrap(ctx, &cfg)
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring("must use IP instead of hostname"))
|
||||
})
|
||||
|
@ -184,7 +189,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := NewBootstrap(&cfg)
|
||||
_, err := NewBootstrap(ctx, &cfg)
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring("no IPs configured"))
|
||||
})
|
||||
|
@ -203,7 +208,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := NewBootstrap(&cfg)
|
||||
_, err := NewBootstrap(ctx, &cfg)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync/atomic"
|
||||
|
@ -35,11 +36,18 @@ type CachingResolver struct {
|
|||
}
|
||||
|
||||
// NewCachingResolver creates a new resolver instance
|
||||
func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) *CachingResolver {
|
||||
return newCachingResolver(cfg, redis, true)
|
||||
func NewCachingResolver(ctx context.Context,
|
||||
cfg config.CachingConfig,
|
||||
redis *redis.Client,
|
||||
) *CachingResolver {
|
||||
return newCachingResolver(ctx, cfg, redis, true)
|
||||
}
|
||||
|
||||
func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetricEvents bool) *CachingResolver {
|
||||
func newCachingResolver(ctx context.Context,
|
||||
cfg config.CachingConfig,
|
||||
redis *redis.Client,
|
||||
emitMetricEvents bool,
|
||||
) *CachingResolver {
|
||||
c := &CachingResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("caching"),
|
||||
|
@ -48,17 +56,17 @@ func newCachingResolver(cfg config.CachingConfig, redis *redis.Client, emitMetri
|
|||
emitMetricEvents: emitMetricEvents,
|
||||
}
|
||||
|
||||
configureCaches(c, &cfg)
|
||||
configureCaches(ctx, c, &cfg)
|
||||
|
||||
if c.redisClient != nil {
|
||||
setupRedisCacheSubscriber(c)
|
||||
setupRedisCacheSubscriber(ctx, c)
|
||||
c.redisClient.GetRedisCache()
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
||||
func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.CachingConfig) {
|
||||
options := expirationcache.Options{
|
||||
CleanupInterval: defaultCachingCleanUpInterval,
|
||||
MaxSize: uint(cfg.MaxItemsCount),
|
||||
|
@ -91,9 +99,9 @@ func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
|||
},
|
||||
}
|
||||
|
||||
c.resultCache = expirationcache.NewPrefetchingCache(prefetchingOptions)
|
||||
c.resultCache = expirationcache.NewPrefetchingCache(ctx, prefetchingOptions)
|
||||
} else {
|
||||
c.resultCache = expirationcache.NewCache[dns.Msg](options)
|
||||
c.resultCache = expirationcache.NewCache[dns.Msg](ctx, options)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,13 +125,19 @@ func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*dns.Msg, time.Dura
|
|||
return nil, 0
|
||||
}
|
||||
|
||||
func setupRedisCacheSubscriber(c *CachingResolver) {
|
||||
func setupRedisCacheSubscriber(ctx context.Context, c *CachingResolver) {
|
||||
go func() {
|
||||
for rc := range c.redisClient.CacheChannel {
|
||||
if rc != nil {
|
||||
c.log().Debug("Received key from redis: ", rc.Key)
|
||||
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
||||
c.putInCache(rc.Key, rc.Response, ttl, false)
|
||||
for {
|
||||
select {
|
||||
case rc := <-c.redisClient.CacheChannel:
|
||||
if rc != nil {
|
||||
c.log().Debug("Received key from redis: ", rc.Key)
|
||||
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
||||
c.putInCache(rc.Key, rc.Response, ttl, false)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -26,6 +27,8 @@ var _ = Describe("CachingResolver", func() {
|
|||
sutConfig config.CachingConfig
|
||||
m *mockResolver
|
||||
mockAnswer *dns.Msg
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -43,7 +46,10 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sut = NewCachingResolver(sutConfig, nil)
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sut = NewCachingResolver(ctx, sutConfig, nil)
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
sut.Next(m)
|
||||
|
@ -79,7 +85,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
It("should prefetch domain if query count > threshold", func() {
|
||||
// prepare resolver, set smaller caching times for testing
|
||||
prefetchThreshold := 5
|
||||
configureCaches(sut, &sutConfig)
|
||||
configureCaches(ctx, sut, &sutConfig)
|
||||
|
||||
domainPrefetched := make(chan bool, 1)
|
||||
prefetchHitDomain := make(chan bool, 1)
|
||||
|
@ -725,7 +731,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
}
|
||||
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1000, A, "1.1.1.1")
|
||||
|
||||
sut = NewCachingResolver(sutConfig, redisClient)
|
||||
sut = NewCachingResolver(ctx, sutConfig, redisClient)
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
sut.Next(m)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -26,7 +27,7 @@ type ClientNamesResolver struct {
|
|||
}
|
||||
|
||||
// NewClientNamesResolver creates new resolver instance
|
||||
func NewClientNamesResolver(
|
||||
func NewClientNamesResolver(ctx context.Context,
|
||||
cfg config.ClientLookupConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (cr *ClientNamesResolver, err error) {
|
||||
var r Resolver
|
||||
|
@ -41,7 +42,7 @@ func NewClientNamesResolver(
|
|||
configurable: withConfig(&cfg),
|
||||
typed: withType("client_names"),
|
||||
|
||||
cache: expirationcache.NewCache[[]string](expirationcache.Options{
|
||||
cache: expirationcache.NewCache[[]string](ctx, expirationcache.Options{
|
||||
CleanupInterval: time.Hour,
|
||||
}),
|
||||
externalResolver: r,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
|
@ -21,6 +22,8 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
sut *ClientNamesResolver
|
||||
sutConfig config.ClientLookupConfig
|
||||
m *mockResolver
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -31,8 +34,10 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
JustBeforeEach(func() {
|
||||
var err error
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sut, err = NewClientNamesResolver(sutConfig, nil, false)
|
||||
sut, err = NewClientNamesResolver(ctx, sutConfig, nil, false)
|
||||
Expect(err).Should(Succeed())
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
|
@ -361,9 +366,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
Describe("Connstruction", func() {
|
||||
When("upstream is invalid", func() {
|
||||
It("errors during construction", func() {
|
||||
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
r, err := NewClientNamesResolver(config.ClientLookupConfig{
|
||||
r, err := NewClientNamesResolver(ctx, config.ClientLookupConfig{
|
||||
Upstream: config.Upstream{Host: "example.com"},
|
||||
}, b, true)
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -184,7 +186,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
|
||||
When("upstream is invalid", func() {
|
||||
It("errors during construction", func() {
|
||||
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
|
||||
Mapping: config.ConditionalUpstreamMapping{
|
||||
|
|
|
@ -34,7 +34,10 @@ type HostsFileResolver struct {
|
|||
downloader lists.FileDownloader
|
||||
}
|
||||
|
||||
func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*HostsFileResolver, error) {
|
||||
func NewHostsFileResolver(ctx context.Context,
|
||||
cfg config.HostsFileConfig,
|
||||
bootstrap *Bootstrap,
|
||||
) (*HostsFileResolver, error) {
|
||||
r := HostsFileResolver{
|
||||
configurable: withConfig(&cfg),
|
||||
typed: withType("hosts_file"),
|
||||
|
@ -42,7 +45,7 @@ func NewHostsFileResolver(cfg config.HostsFileConfig, bootstrap *Bootstrap) (*Ho
|
|||
downloader: lists.NewDownloader(cfg.Loading.Downloads, bootstrap.NewHTTPTransport()),
|
||||
}
|
||||
|
||||
err := cfg.Loading.StartPeriodicRefresh(r.loadSources, func(err error) {
|
||||
err := cfg.Loading.StartPeriodicRefresh(ctx, r.loadSources, func(err error) {
|
||||
r.log().WithError(err).Errorf("could not load hosts files")
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -53,7 +53,9 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
JustBeforeEach(func() {
|
||||
var err error
|
||||
|
||||
sut, err = NewHostsFileResolver(sutConfig, systemResolverBootstrap)
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
m = &mockResolver{}
|
||||
|
|
|
@ -118,10 +118,10 @@ func autoAnswer(qType dns.Type, qName string) (*dns.Msg, error) {
|
|||
}
|
||||
|
||||
// newTestBootstrap creates a test Bootstrap
|
||||
func newTestBootstrap(response *dns.Msg) *Bootstrap {
|
||||
func newTestBootstrap(ctx context.Context, response *dns.Msg) *Bootstrap {
|
||||
bootstrapUpstream := &mockResolver{}
|
||||
|
||||
b, err := NewBootstrap(&config.Config{})
|
||||
b, err := NewBootstrap(ctx, &config.Config{})
|
||||
util.FatalOnError("can't create bootstrap", err)
|
||||
|
||||
b.resolver = bootstrapUpstream
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -24,6 +25,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
sut *ParallelBestResolver
|
||||
sutMapping config.UpstreamGroups
|
||||
sutVerify bool
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
|
||||
err error
|
||||
|
||||
|
@ -37,6 +40,9 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
config.Upstream{
|
||||
|
@ -111,7 +117,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
|
||||
When("no upstream resolvers can be reached", func() {
|
||||
BeforeEach(func() {
|
||||
bootstrap = newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
bootstrap = newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
sutMapping = config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {
|
||||
|
@ -328,7 +334,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
|
||||
When("upstream is invalid", func() {
|
||||
It("errors during construction", func() {
|
||||
b := newTestBootstrap(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
r, err := NewParallelBestResolver(config.UpstreamsConfig{
|
||||
Groups: config.UpstreamGroups{"test": {{Host: "example.com"}}},
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -29,7 +30,7 @@ type QueryLoggingResolver struct {
|
|||
}
|
||||
|
||||
// NewQueryLoggingResolver returns a new resolver instance
|
||||
func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
|
||||
func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLogConfig) *QueryLoggingResolver {
|
||||
logger := log.PrefixedLog(queryLoggingResolverType)
|
||||
|
||||
var writer querylog.Writer
|
||||
|
@ -43,10 +44,10 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
|
|||
case config.QueryLogTypeCsvClient:
|
||||
writer, err = querylog.NewCSVWriter(cfg.Target, true, cfg.LogRetentionDays)
|
||||
case config.QueryLogTypeMysql:
|
||||
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays,
|
||||
writer, err = querylog.NewDatabaseWriter(ctx, "mysql", cfg.Target, cfg.LogRetentionDays,
|
||||
cfg.FlushInterval.ToDuration())
|
||||
case config.QueryLogTypePostgresql:
|
||||
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays,
|
||||
writer, err = querylog.NewDatabaseWriter(ctx, "postgresql", cfg.Target, cfg.LogRetentionDays,
|
||||
cfg.FlushInterval.ToDuration())
|
||||
case config.QueryLogTypeConsole:
|
||||
writer = querylog.NewLoggerWriter()
|
||||
|
@ -81,23 +82,27 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) *QueryLoggingResolver {
|
|||
writer: writer,
|
||||
}
|
||||
|
||||
go resolver.writeLog()
|
||||
go resolver.writeLog(ctx)
|
||||
|
||||
if cfg.LogRetentionDays > 0 {
|
||||
go resolver.periodicCleanUp()
|
||||
go resolver.periodicCleanUp(ctx)
|
||||
}
|
||||
|
||||
return &resolver
|
||||
}
|
||||
|
||||
// triggers periodically cleanup of old log files
|
||||
func (r *QueryLoggingResolver) periodicCleanUp() {
|
||||
func (r *QueryLoggingResolver) periodicCleanUp(ctx context.Context) {
|
||||
ticker := time.NewTicker(cleanUpRunPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
r.doCleanUp()
|
||||
select {
|
||||
case <-ticker.C:
|
||||
r.doCleanUp()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -164,18 +169,23 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response *
|
|||
}
|
||||
|
||||
// write entry: if log directory is configured, write to log file
|
||||
func (r *QueryLoggingResolver) writeLog() {
|
||||
for logEntry := range r.logChan {
|
||||
start := time.Now()
|
||||
func (r *QueryLoggingResolver) writeLog(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case logEntry := <-r.logChan:
|
||||
start := time.Now()
|
||||
|
||||
r.writer.Write(logEntry)
|
||||
r.writer.Write(logEntry)
|
||||
|
||||
halfCap := cap(r.logChan) / 2 //nolint:gomnd
|
||||
halfCap := cap(r.logChan) / 2 //nolint:gomnd
|
||||
|
||||
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
|
||||
if len(r.logChan) > halfCap {
|
||||
r.log().WithField("channel_len",
|
||||
len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
|
||||
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
|
||||
if len(r.logChan) > halfCap {
|
||||
r.log().WithField("channel_len",
|
||||
len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -63,8 +64,10 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
sutConfig.SetDefaults() // not called when using a struct literal
|
||||
}
|
||||
|
||||
sut = NewQueryLoggingResolver(sutConfig)
|
||||
DeferCleanup(func() { close(sut.logChan) })
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sut = NewQueryLoggingResolver(ctx, sutConfig)
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
|
||||
sut.Next(m)
|
||||
|
@ -151,7 +154,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
g.Expect(csvLines[0][7]).Should(Equal("NOERROR"))
|
||||
g.Expect(csvLines[0][8]).Should(Equal("RESOLVED"))
|
||||
g.Expect(csvLines[0][9]).Should(Equal("A"))
|
||||
}, "1s").Should(Succeed())
|
||||
}).Should(Succeed())
|
||||
})
|
||||
|
||||
By("check log for client2", func() {
|
||||
|
@ -169,7 +172,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
g.Expect(csvLines[0][7]).Should(Equal("NOERROR"))
|
||||
g.Expect(csvLines[0][8]).Should(Equal("RESOLVED"))
|
||||
g.Expect(csvLines[0][9]).Should(Equal("A"))
|
||||
}, "1s").Should(Succeed())
|
||||
}).Should(Succeed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -109,16 +110,24 @@ var _ = Describe("Resolver", func() {
|
|||
})
|
||||
|
||||
Describe("Name", func() {
|
||||
var (
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
})
|
||||
When("'Name' is called", func() {
|
||||
It("should return resolver name", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
br, _ := NewBlockingResolver(ctx, config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
name := Name(br)
|
||||
Expect(name).Should(Equal("blocking"))
|
||||
})
|
||||
})
|
||||
When("'Name' is called on a NamedResolver", func() {
|
||||
It("should return its custom name", func() {
|
||||
br, _ := NewBlockingResolver(config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
br, _ := NewBlockingResolver(ctx, config.BlockingConfig{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
||||
|
||||
cfg := config.RewriterConfig{Rewrite: map[string]string{"not": "empty"}}
|
||||
r := NewRewriterResolver(cfg, br)
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
|
@ -113,7 +114,7 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
|||
// NewServer creates new server instance with passed config
|
||||
//
|
||||
//nolint:funlen
|
||||
func NewServer(cfg *config.Config) (server *Server, err error) {
|
||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||
log.ConfigureLogger(&cfg.Log)
|
||||
|
||||
var cert tls.Certificate
|
||||
|
@ -145,7 +146,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
|
|||
|
||||
metrics.RegisterEventListeners()
|
||||
|
||||
bootstrap, err := resolver.NewBootstrap(cfg)
|
||||
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -155,7 +156,7 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
|
|||
return nil, redisErr
|
||||
}
|
||||
|
||||
queryResolver, queryError := createQueryResolver(cfg, bootstrap, redisClient)
|
||||
queryResolver, queryError := createQueryResolver(ctx, cfg, bootstrap, redisClient)
|
||||
if queryError != nil {
|
||||
return nil, queryError
|
||||
}
|
||||
|
@ -385,6 +386,7 @@ func createSelfSignedCert() (tls.Certificate, error) {
|
|||
}
|
||||
|
||||
func createQueryResolver(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
bootstrap *resolver.Bootstrap,
|
||||
redisClient *redis.Client,
|
||||
|
@ -396,10 +398,10 @@ func createQueryResolver(
|
|||
|
||||
upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches)
|
||||
|
||||
blocking, blErr := resolver.NewBlockingResolver(cfg.Blocking, redisClient, bootstrap)
|
||||
clientNames, cnErr := resolver.NewClientNamesResolver(cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
|
||||
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
|
||||
hostsFile, hfErr := resolver.NewHostsFileResolver(cfg.HostsFile, bootstrap)
|
||||
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
|
||||
|
||||
err = multierror.Append(
|
||||
multierror.Prefix(utErr, "upstream tree resolver: "),
|
||||
|
@ -417,12 +419,12 @@ func createQueryResolver(
|
|||
resolver.NewFqdnOnlyResolver(cfg.FqdnOnly),
|
||||
clientNames,
|
||||
resolver.NewEdeResolver(cfg.Ede),
|
||||
resolver.NewQueryLoggingResolver(cfg.QueryLog),
|
||||
resolver.NewQueryLoggingResolver(ctx, cfg.QueryLog),
|
||||
resolver.NewMetricsResolver(cfg.Prometheus),
|
||||
resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)),
|
||||
hostsFile,
|
||||
blocking,
|
||||
resolver.NewCachingResolver(cfg.Caching, redisClient),
|
||||
resolver.NewCachingResolver(ctx, cfg.Caching, redisClient),
|
||||
resolver.NewRewriterResolver(cfg.Conditional.RewriterConfig, condUpstream),
|
||||
resolver.NewSpecialUseDomainNamesResolver(cfg.SUDN),
|
||||
upstreamTree,
|
||||
|
@ -513,7 +515,7 @@ const (
|
|||
)
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start(errCh chan<- error) {
|
||||
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
||||
logger().Info("Starting server")
|
||||
|
||||
for _, srv := range s.dnsServers {
|
||||
|
@ -572,7 +574,7 @@ func (s *Server) Start(errCh chan<- error) {
|
|||
}()
|
||||
}
|
||||
|
||||
registerPrintConfigurationTrigger(s)
|
||||
registerPrintConfigurationTrigger(ctx, s)
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
|
|
|
@ -4,19 +4,25 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func registerPrintConfigurationTrigger(s *Server) {
|
||||
func registerPrintConfigurationTrigger(ctx context.Context, s *Server) {
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
<-signals
|
||||
s.printConfiguration()
|
||||
select {
|
||||
case <-signals:
|
||||
s.printConfiguration()
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -32,6 +33,8 @@ var (
|
|||
|
||||
var _ = BeforeSuite(func() {
|
||||
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
googleMockUpstream := resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
if request.Question[0].Name == "error." {
|
||||
return nil
|
||||
|
@ -102,7 +105,7 @@ var _ = BeforeSuite(func() {
|
|||
Expect(youtubeFile.Error).Should(Succeed())
|
||||
|
||||
// create server
|
||||
sut, err = NewServer(&config.Config{
|
||||
sut, err = NewServer(ctx, &config.Config{
|
||||
CustomDNS: config.CustomDNSConfig{
|
||||
CustomTTL: config.Duration(3600 * time.Second),
|
||||
Mapping: config.CustomDNSMapping{
|
||||
|
@ -169,7 +172,7 @@ var _ = BeforeSuite(func() {
|
|||
|
||||
// start server
|
||||
go func() {
|
||||
sut.Start(errChan)
|
||||
sut.Start(ctx, errChan)
|
||||
}()
|
||||
DeferCleanup(sut.Stop)
|
||||
|
||||
|
@ -177,6 +180,14 @@ var _ = BeforeSuite(func() {
|
|||
})
|
||||
|
||||
var _ = Describe("Running DNS server", func() {
|
||||
var (
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
})
|
||||
Describe("performing DNS request with running server", func() {
|
||||
BeforeEach(func() {
|
||||
mockClientName.Store("")
|
||||
|
@ -555,14 +566,14 @@ var _ = Describe("Running DNS server", func() {
|
|||
})
|
||||
When("Server is created", func() {
|
||||
It("is created without redis connection", func() {
|
||||
_, err := NewServer(&cfg)
|
||||
_, err := NewServer(ctx, &cfg)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("can't be created if redis server is unavailable", func() {
|
||||
cfg.Redis.Required = true
|
||||
|
||||
_, err := NewServer(&cfg)
|
||||
_, err := NewServer(ctx, &cfg)
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
})
|
||||
|
@ -573,7 +584,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
When("Server start is called", func() {
|
||||
It("start was called 2 times, start should fail", func() {
|
||||
// create server
|
||||
server, err := NewServer(&config.Config{
|
||||
server, err := NewServer(ctx, &config.Config{
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Groups: map[string][]config.Upstream{
|
||||
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
|
||||
|
@ -598,14 +609,14 @@ var _ = Describe("Running DNS server", func() {
|
|||
errChan := make(chan error, 10)
|
||||
|
||||
// start server
|
||||
go server.Start(errChan)
|
||||
go server.Start(ctx, errChan)
|
||||
|
||||
DeferCleanup(server.Stop)
|
||||
|
||||
Consistently(errChan, "1s").ShouldNot(Receive())
|
||||
|
||||
// start again -> should fail
|
||||
server.Start(errChan)
|
||||
server.Start(ctx, errChan)
|
||||
|
||||
Eventually(errChan).Should(Receive())
|
||||
})
|
||||
|
@ -615,7 +626,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
When("Stop is called", func() {
|
||||
It("stop was called 2 times, start should fail", func() {
|
||||
// create server
|
||||
server, err := NewServer(&config.Config{
|
||||
server, err := NewServer(ctx, &config.Config{
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Groups: map[string][]config.Upstream{
|
||||
"default": {config.Upstream{Net: config.NetProtocolTcpUdp, Host: "4.4.4.4", Port: 53}},
|
||||
|
@ -641,7 +652,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
|
||||
// start server
|
||||
go func() {
|
||||
server.Start(errChan)
|
||||
server.Start(ctx, errChan)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
@ -681,7 +692,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
Describe("create query resolver", func() {
|
||||
When("some upstream returns error", func() {
|
||||
It("create query resolver should return error", func() {
|
||||
r, err := createQueryResolver(&config.Config{
|
||||
r, err := createQueryResolver(ctx, &config.Config{
|
||||
StartVerifyUpstream: true,
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Groups: config.UpstreamGroups{
|
||||
|
@ -736,8 +747,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
cfg.Ports = config.PortsConfig{
|
||||
HTTPS: []string{":14443"},
|
||||
}
|
||||
|
||||
sut, err := NewServer(&cfg)
|
||||
sut, err := NewServer(ctx, &cfg)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(sut.cert.Certificate).ShouldNot(BeNil())
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue