diff --git a/cmd/serve.go b/cmd/serve.go index 03de155d..6520e3a4 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -63,7 +63,7 @@ func startServer(_ *cobra.Command, _ []string) error { select { case <-signals: log.Log().Infof("Terminating...") - util.LogOnError("can't stop server: ", srv.Stop(ctx)) + util.LogOnError(ctx, "can't stop server: ", srv.Stop(ctx)) done <- true case err := <-errChan: diff --git a/querylog/database_writer.go b/querylog/database_writer.go index f7856e7e..1dc40c9a 100644 --- a/querylog/database_writer.go +++ b/querylog/database_writer.go @@ -128,7 +128,7 @@ func (d *DatabaseWriter) periodicFlush(ctx context.Context) { case <-ticker.C: err := d.doDBWrite() - util.LogOnError("can't write entries to the database: ", err) + util.LogOnError(ctx, "can't write entries to the database: ", err) case <-ctx.Done(): return diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 51579890..ecc62eef 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -125,7 +125,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) return &packed, r.adjustTTLs(response.Res.Answer) } } else { - util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err) + util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err) } return nil, 0 @@ -140,7 +140,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { if rc != nil { logger.Debug("Received key from redis: ", rc.Key) ttl := r.adjustTTLs(rc.Response.Res.Answer) - r.putInCache(rc.Key, rc.Response, ttl, false) + r.putInCache(ctx, rc.Key, rc.Response, ttl, false) } case <-ctx.Done(): @@ -194,7 +194,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( if err == nil { cacheTTL := r.adjustTTLs(response.Res.Answer) - r.putInCache(cacheKey, response, cacheTTL, true) + r.putInCache(ctx, cacheKey, response, cacheTTL, true) } } @@ -250,8 +250,8 @@ func isResponseCacheable(msg *dns.Msg) bool { return !msg.Truncated && !msg.CheckingDisabled } -func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration, - publish bool, +func (r *CachingResolver) putInCache( + ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool, ) { respCopy := response.Res.Copy() @@ -259,7 +259,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, util.RemoveEdns0Record(respCopy) packed, err := respCopy.Pack() - util.LogOnError("error on packing", err) + util.LogOnError(ctx, "error on packing", err) if err == nil { if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) { diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go index e8dea8ed..9311e56b 100644 --- a/resolver/parallel_best_resolver.go +++ b/resolver/parallel_best_resolver.go @@ -5,11 +5,13 @@ import ( "errors" "fmt" "math" + "math/rand" "strings" "sync/atomic" "time" "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -161,7 +163,7 @@ func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Reque ctx, cancel := context.WithCancel(ctx) defer cancel() // abort requests to resolvers that lost the race - resolvers := pickRandom(allResolvers, r.resolverCount) + resolvers := pickRandom(ctx, allResolvers, r.resolverCount) ch := make(chan requestResponse, len(resolvers)) for _, resolver := range resolvers { @@ -210,7 +212,7 @@ func (r *ParallelBestResolver) retryWithDifferent( ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus, ) (*model.Response, error) { // second try (if retryWithDifferentResolver == true) - resolver := weightedRandom(*r.resolvers.Load(), resolvers) + resolver := weightedRandom(ctx, *r.resolvers.Load(), resolvers) logger.Debugf("using %s as second resolver", resolver.resolver) resp, err := resolver.resolve(ctx, request) @@ -227,17 +229,17 @@ func (r *ParallelBestResolver) retryWithDifferent( } // pickRandom picks n (resolverCount) different random resolvers from the given resolver pool -func pickRandom(resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus { +func pickRandom(ctx context.Context, resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus { chosenResolvers := make([]*upstreamResolverStatus, 0, resolverCount) for i := 0; i < resolverCount; i++ { - chosenResolvers = append(chosenResolvers, weightedRandom(resolvers, chosenResolvers)) + chosenResolvers = append(chosenResolvers, weightedRandom(ctx, resolvers, chosenResolvers)) } return chosenResolvers } -func weightedRandom(in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus { +func weightedRandom(ctx context.Context, in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus { const errorWindowInSec = 60 choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in)) @@ -262,7 +264,13 @@ outer: } c, err := weightedrand.NewChooser(choices...) - util.LogOnError("can't choose random weighted resolver: ", err) + if err != nil { + log.FromCtx(ctx).WithError(err).Error("can't choose random weighted resolver, falling back to uniform random") + + val := rand.Int() //nolint:gosec // pseudo-randomness is good enough + + return choices[val%len(choices)].Item + } return c.Pick() } diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index 631ec181..3e960d78 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -301,12 +301,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2} }) - It("should use 2 random peeked resolvers, weighted with last error timestamp", func() { + It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) { By("all resolvers have same weight for random -> equal distribution", func() { resolverCount := make(map[Resolver]int) for i := 0; i < 1000; i++ { - resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount) + resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount) res1 := resolvers[0].resolver res2 := resolvers[1].resolver Expect(res1).ShouldNot(Equal(res2)) @@ -330,7 +330,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := make(map[*UpstreamResolver]int) for i := 0; i < 100; i++ { - resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount) + resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount) res1 := resolvers[0].resolver.(*UpstreamResolver) res2 := resolvers[1].resolver.(*UpstreamResolver) Expect(res1).ShouldNot(Equal(res2)) @@ -493,12 +493,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2} }) - It("should use 2 random peeked resolvers, weighted with last error timestamp", func() { + It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) { By("all resolvers have same weight for random -> equal distribution", func() { resolverCount := make(map[Resolver]int) for i := 0; i < 2000; i++ { - r := weightedRandom(*sut.resolvers.Load(), nil) + r := weightedRandom(ctx, *sut.resolvers.Load(), nil) resolverCount[r.resolver]++ } for _, v := range resolverCount { @@ -517,7 +517,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { resolverCount := make(map[*UpstreamResolver]int) for i := 0; i < 200; i++ { - r := weightedRandom(*sut.resolvers.Load(), nil) + r := weightedRandom(ctx, *sut.resolvers.Load(), nil) res := r.resolver.(*UpstreamResolver) resolverCount[res]++ diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go index f4759ae0..51cf8d08 100644 --- a/resolver/upstream_resolver.go +++ b/resolver/upstream_resolver.go @@ -162,7 +162,7 @@ func (r *httpUpstreamClient) callExternal( } defer func() { - util.LogOnError("can't close response body ", httpResponse.Body.Close()) + util.LogOnError(ctx, "can't close response body ", httpResponse.Body.Close()) }() if httpResponse.StatusCode != http.StatusOK { diff --git a/server/server.go b/server/server.go index fdccec89..58c69e30 100644 --- a/server/server.go +++ b/server/server.go @@ -426,7 +426,9 @@ func (s *Server) registerDNSHandlers(ctx context.Context) { for _, server := range s.dnsServers { handler := server.Handler.(*dns.ServeMux) handler.HandleFunc(".", wrappedOnRequest) - handler.HandleFunc("healthcheck.blocky", s.OnHealthCheck) + handler.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, m *dns.Msg) { + s.OnHealthCheck(ctx, w, m) + }) } } @@ -606,10 +608,10 @@ func (s *Server) OnRequest( m := new(dns.Msg) m.SetRcode(request, dns.RcodeServerFailure) err := w.WriteMsg(m) - util.LogOnError("can't write message: ", err) + util.LogOnError(ctx, "can't write message: ", err) } else { err := w.WriteMsg(response.Res) - util.LogOnError("can't write message: ", err) + util.LogOnError(ctx, "can't write message: ", err) } } @@ -672,13 +674,13 @@ func getMaxResponseSize(req *model.Request) int { } // OnHealthCheck Handler for docker health check. Just returns OK code without delegating to resolver chain -func (s *Server) OnHealthCheck(w dns.ResponseWriter, request *dns.Msg) { +func (s *Server) OnHealthCheck(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) { resp := new(dns.Msg) resp.SetReply(request) resp.Rcode = dns.RcodeSuccess err := w.WriteMsg(resp) - util.LogOnError("can't write message: ", err) + util.LogOnError(ctx, "can't write message: ", err) } func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol model.RequestProtocol) { diff --git a/util/common.go b/util/common.go index 546399b5..e16de3ad 100644 --- a/util/common.go +++ b/util/common.go @@ -1,6 +1,7 @@ package util import ( + "context" "encoding/binary" "fmt" "io" @@ -154,9 +155,9 @@ func IterateValueSorted(in map[string]int, fn func(string, int)) { } // LogOnError logs the message only if error is not nil -func LogOnError(message string, err error) { +func LogOnError(ctx context.Context, message string, err error) { if err != nil { - log.Log().Error(message, err) + log.FromCtx(ctx).Error(message, err) } } diff --git a/util/common_test.go b/util/common_test.go index 6ed0fc07..f539e5a2 100644 --- a/util/common_test.go +++ b/util/common_test.go @@ -1,6 +1,7 @@ package util import ( + "context" "errors" "fmt" "net" @@ -153,11 +154,11 @@ var _ = Describe("Common function tests", func() { Describe("Logging functions", func() { When("LogOnError is called with error", func() { err := errors.New("test") - It("should log", func() { + It("should log", func(ctx context.Context) { hook := test.NewGlobal() Log().AddHook(hook) defer hook.Reset() - LogOnError("message ", err) + LogOnError(ctx, "message ", err) Expect(hook.LastEntry().Message).Should(Equal("message test")) }) })