refactor(util): make `LogOnError` get the log from a `Context`

ThinkChaos 2024-01-26 22:07:26 -05:00
parent b335887992
commit 3fcf379df7
9 changed files with 42 additions and 30 deletions

View File

@@ -63,7 +63,7 @@ func startServer(_ *cobra.Command, _ []string) error {
     select {
     case <-signals:
         log.Log().Infof("Terminating...")
-        util.LogOnError("can't stop server: ", srv.Stop(ctx))
+        util.LogOnError(ctx, "can't stop server: ", srv.Stop(ctx))
         done <- true
     case err := <-errChan:

View File

@@ -128,7 +128,7 @@ func (d *DatabaseWriter) periodicFlush(ctx context.Context) {
         case <-ticker.C:
             err := d.doDBWrite()
-            util.LogOnError("can't write entries to the database: ", err)
+            util.LogOnError(ctx, "can't write entries to the database: ", err)
         case <-ctx.Done():
             return

View File

@@ -125,7 +125,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string)
             return &packed, r.adjustTTLs(response.Res.Answer)
         }
     } else {
-        util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
+        util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err)
     }

     return nil, 0
@@ -140,7 +140,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) {
             if rc != nil {
                 logger.Debug("Received key from redis: ", rc.Key)
                 ttl := r.adjustTTLs(rc.Response.Res.Answer)
-                r.putInCache(rc.Key, rc.Response, ttl, false)
+                r.putInCache(ctx, rc.Key, rc.Response, ttl, false)
             }
         case <-ctx.Done():
@@ -194,7 +194,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
         if err == nil {
             cacheTTL := r.adjustTTLs(response.Res.Answer)
-            r.putInCache(cacheKey, response, cacheTTL, true)
+            r.putInCache(ctx, cacheKey, response, cacheTTL, true)
         }
     }
@@ -250,8 +250,8 @@ func isResponseCacheable(msg *dns.Msg) bool {
     return !msg.Truncated && !msg.CheckingDisabled
 }

-func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
-    publish bool,
+func (r *CachingResolver) putInCache(
+    ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool,
 ) {
     respCopy := response.Res.Copy()
@@ -259,7 +259,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
     util.RemoveEdns0Record(respCopy)

     packed, err := respCopy.Pack()
-    util.LogOnError("error on packing", err)
+    util.LogOnError(ctx, "error on packing", err)

     if err == nil {
         if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) {

View File

@@ -5,11 +5,13 @@ import (
     "errors"
     "fmt"
     "math"
+    "math/rand"
     "strings"
     "sync/atomic"
     "time"

     "github.com/0xERR0R/blocky/config"
+    "github.com/0xERR0R/blocky/log"
     "github.com/0xERR0R/blocky/model"
     "github.com/0xERR0R/blocky/util"
@@ -161,7 +163,7 @@ func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Reque
     ctx, cancel := context.WithCancel(ctx)
     defer cancel() // abort requests to resolvers that lost the race

-    resolvers := pickRandom(allResolvers, r.resolverCount)
+    resolvers := pickRandom(ctx, allResolvers, r.resolverCount)
     ch := make(chan requestResponse, len(resolvers))

     for _, resolver := range resolvers {
@@ -210,7 +212,7 @@ func (r *ParallelBestResolver) retryWithDifferent(
     ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
 ) (*model.Response, error) {
     // second try (if retryWithDifferentResolver == true)
-    resolver := weightedRandom(*r.resolvers.Load(), resolvers)
+    resolver := weightedRandom(ctx, *r.resolvers.Load(), resolvers)
     logger.Debugf("using %s as second resolver", resolver.resolver)

     resp, err := resolver.resolve(ctx, request)
@@ -227,17 +229,17 @@ func (r *ParallelBestResolver) retryWithDifferent(
 }

 // pickRandom picks n (resolverCount) different random resolvers from the given resolver pool
-func pickRandom(resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus {
+func pickRandom(ctx context.Context, resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus {
     chosenResolvers := make([]*upstreamResolverStatus, 0, resolverCount)

     for i := 0; i < resolverCount; i++ {
-        chosenResolvers = append(chosenResolvers, weightedRandom(resolvers, chosenResolvers))
+        chosenResolvers = append(chosenResolvers, weightedRandom(ctx, resolvers, chosenResolvers))
     }

     return chosenResolvers
 }

-func weightedRandom(in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus {
+func weightedRandom(ctx context.Context, in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus {
     const errorWindowInSec = 60

     choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in))
@@ -262,7 +264,13 @@ outer:
     }

     c, err := weightedrand.NewChooser(choices...)
-    util.LogOnError("can't choose random weighted resolver: ", err)
+    if err != nil {
+        log.FromCtx(ctx).WithError(err).Error("can't choose random weighted resolver, falling back to uniform random")
+
+        val := rand.Int() //nolint:gosec // pseudo-randomness is good enough
+
+        return choices[val%len(choices)].Item
+    }

     return c.Pick()
 }
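Besides threading ctx through, the last hunk above fixes a latent bug: on a NewChooser error the old code only logged and then called c.Pick() on a nil chooser, which would panic; the new branch degrades to a uniform random pick instead. A minimal standalone sketch of that fallback pattern, assuming the mroth/weightedrand/v2 API implied by the generic Choice type (pickWeighted is illustrative, not blocky's code):

package main

import (
    "fmt"
    "log"
    "math/rand"

    "github.com/mroth/weightedrand/v2"
)

// pickWeighted returns a weighted random item, degrading to a uniform random
// pick when the chooser can't be built (e.g. every weight is zero).
func pickWeighted[T any](choices []weightedrand.Choice[T, uint]) T {
    c, err := weightedrand.NewChooser(choices...)
    if err != nil {
        log.Printf("can't choose random weighted item, falling back to uniform random: %v", err)

        return choices[rand.Intn(len(choices))].Item //nolint:gosec // pseudo-randomness is good enough
    }

    return c.Pick()
}

func main() {
    choices := []weightedrand.Choice[string, uint]{
        {Item: "upstream-a", Weight: 3}, // picked ~3x as often as upstream-b
        {Item: "upstream-b", Weight: 1},
    }
    fmt.Println(pickWeighted(choices))
}

Note that, like the code in the hunk, the fallback still assumes the choices slice is non-empty; an empty slice would panic in either branch.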

View File

@@ -301,12 +301,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
                 upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
             })

-            It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
+            It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) {
                 By("all resolvers have same weight for random -> equal distribution", func() {
                     resolverCount := make(map[Resolver]int)

                     for i := 0; i < 1000; i++ {
-                        resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
+                        resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount)
                         res1 := resolvers[0].resolver
                         res2 := resolvers[1].resolver
                         Expect(res1).ShouldNot(Equal(res2))
@@ -330,7 +330,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
                     resolverCount := make(map[*UpstreamResolver]int)

                     for i := 0; i < 100; i++ {
-                        resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
+                        resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount)
                         res1 := resolvers[0].resolver.(*UpstreamResolver)
                         res2 := resolvers[1].resolver.(*UpstreamResolver)
                         Expect(res1).ShouldNot(Equal(res2))
@@ -493,12 +493,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
                 upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
             })

-            It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
+            It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) {
                 By("all resolvers have same weight for random -> equal distribution", func() {
                     resolverCount := make(map[Resolver]int)

                     for i := 0; i < 2000; i++ {
-                        r := weightedRandom(*sut.resolvers.Load(), nil)
+                        r := weightedRandom(ctx, *sut.resolvers.Load(), nil)
                         resolverCount[r.resolver]++
                     }

                     for _, v := range resolverCount {
@@ -517,7 +517,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
                     resolverCount := make(map[*UpstreamResolver]int)

                     for i := 0; i < 200; i++ {
-                        r := weightedRandom(*sut.resolvers.Load(), nil)
+                        r := weightedRandom(ctx, *sut.resolvers.Load(), nil)
                         res := r.resolver.(*UpstreamResolver)
                         resolverCount[res]++
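The ctx parameter these test bodies now accept comes from Ginkgo itself: Ginkgo v2 injects a per-spec context into any It body that declares one. A minimal self-contained sketch of that mechanism (the suite and spec names are illustrative, not blocky's):

package example_test

import (
    "context"
    "testing"

    . "github.com/onsi/ginkgo/v2"
    . "github.com/onsi/gomega"
)

func TestExample(t *testing.T) {
    RegisterFailHandler(Fail)
    RunSpecs(t, "Example Suite")
}

var _ = Describe("context injection", func() {
    // Because the body accepts a context.Context, Ginkgo passes in a per-spec
    // context that is cancelled when the spec finishes or times out.
    It("receives a live per-spec context", func(ctx context.Context) {
        Expect(ctx.Err()).ShouldNot(HaveOccurred())
    })
})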

View File

@@ -162,7 +162,7 @@ func (r *httpUpstreamClient) callExternal(
     }

     defer func() {
-        util.LogOnError("can't close response body ", httpResponse.Body.Close())
+        util.LogOnError(ctx, "can't close response body ", httpResponse.Body.Close())
     }()

     if httpResponse.StatusCode != http.StatusOK {

View File

@@ -426,7 +426,9 @@ func (s *Server) registerDNSHandlers(ctx context.Context) {
     for _, server := range s.dnsServers {
         handler := server.Handler.(*dns.ServeMux)
         handler.HandleFunc(".", wrappedOnRequest)
-        handler.HandleFunc("healthcheck.blocky", s.OnHealthCheck)
+        handler.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, m *dns.Msg) {
+            s.OnHealthCheck(ctx, w, m)
+        })
     }
 }
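dns.ServeMux.HandleFunc only accepts a func(dns.ResponseWriter, *dns.Msg), so the now context-aware OnHealthCheck can no longer be registered directly; the hunk above wraps it in a closure that captures ctx. The same adaptation written as a reusable helper, as a sketch assuming miekg/dns (withCtx and onHealthCheck are hypothetical stand-ins, not blocky code):

package main

import (
    "context"
    "log"

    "github.com/miekg/dns"
)

// withCtx adapts a context-aware DNS handler to the plain signature that
// dns.ServeMux.HandleFunc expects, capturing ctx in a closure.
func withCtx(
    ctx context.Context,
    h func(context.Context, dns.ResponseWriter, *dns.Msg),
) func(dns.ResponseWriter, *dns.Msg) {
    return func(w dns.ResponseWriter, m *dns.Msg) {
        h(ctx, w, m)
    }
}

// onHealthCheck is a stand-in for blocky's (*Server).OnHealthCheck.
func onHealthCheck(_ context.Context, w dns.ResponseWriter, req *dns.Msg) {
    resp := new(dns.Msg)
    resp.SetReply(req)
    resp.Rcode = dns.RcodeSuccess

    if err := w.WriteMsg(resp); err != nil {
        log.Println("can't write message:", err)
    }
}

func main() {
    mux := dns.NewServeMux()
    mux.HandleFunc("healthcheck.blocky", withCtx(context.Background(), onHealthCheck))

    log.Fatal((&dns.Server{Addr: ":5353", Net: "udp", Handler: mux}).ListenAndServe())
}

With such a helper the registration would read handler.HandleFunc("healthcheck.blocky", withCtx(ctx, s.OnHealthCheck)); the commit inlines the closure instead, which is equivalent.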
@@ -606,10 +608,10 @@ func (s *Server) OnRequest(
         m := new(dns.Msg)
         m.SetRcode(request, dns.RcodeServerFailure)

         err := w.WriteMsg(m)
-        util.LogOnError("can't write message: ", err)
+        util.LogOnError(ctx, "can't write message: ", err)
     } else {
         err := w.WriteMsg(response.Res)
-        util.LogOnError("can't write message: ", err)
+        util.LogOnError(ctx, "can't write message: ", err)
     }
 }
@@ -672,13 +674,13 @@ func getMaxResponseSize(req *model.Request) int {
 }

 // OnHealthCheck Handler for docker health check. Just returns OK code without delegating to resolver chain
-func (s *Server) OnHealthCheck(w dns.ResponseWriter, request *dns.Msg) {
+func (s *Server) OnHealthCheck(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
     resp := new(dns.Msg)
     resp.SetReply(request)
     resp.Rcode = dns.RcodeSuccess

     err := w.WriteMsg(resp)
-    util.LogOnError("can't write message: ", err)
+    util.LogOnError(ctx, "can't write message: ", err)
 }

 func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol model.RequestProtocol) {

View File

@@ -1,6 +1,7 @@
 package util

 import (
+    "context"
     "encoding/binary"
     "fmt"
     "io"
@@ -154,9 +155,9 @@ func IterateValueSorted(in map[string]int, fn func(string, int)) {
 }

 // LogOnError logs the message only if error is not nil
-func LogOnError(message string, err error) {
+func LogOnError(ctx context.Context, message string, err error) {
     if err != nil {
-        log.Log().Error(message, err)
+        log.FromCtx(ctx).Error(message, err)
     }
 }
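log.FromCtx itself is not part of this diff; blocky's log package already provides it. For readers following along, a minimal sketch of the usual shape of such a helper, assuming a logrus-based logger stored under a private context key (the ctxKey type and WithCtx name are illustrative, not necessarily blocky's):

package log

import (
    "context"

    "github.com/sirupsen/logrus"
)

// ctxKey is a private key type so no other package can collide with it.
type ctxKey struct{}

// WithCtx returns a copy of ctx that carries the given logger.
func WithCtx(ctx context.Context, logger *logrus.Entry) context.Context {
    return context.WithValue(ctx, ctxKey{}, logger)
}

// FromCtx returns the logger stored in ctx, falling back to the global
// logger so LogOnError(ctx, ...) is always safe to call.
func FromCtx(ctx context.Context) *logrus.Entry {
    if logger, ok := ctx.Value(ctxKey{}).(*logrus.Entry); ok {
        return logger
    }

    return logrus.NewEntry(logrus.StandardLogger())
}

The payoff of threading ctx into LogOnError at every call site above is that a request-scoped logger, attached once near the top of the call chain, is the one that reports errors deep inside it, rather than the bare global logger.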

View File

@@ -1,6 +1,7 @@
 package util

 import (
+    "context"
     "errors"
     "fmt"
     "net"
@@ -153,11 +154,11 @@ var _ = Describe("Common function tests", func() {
     Describe("Logging functions", func() {
         When("LogOnError is called with error", func() {
             err := errors.New("test")

-            It("should log", func() {
+            It("should log", func(ctx context.Context) {
                 hook := test.NewGlobal()
                 Log().AddHook(hook)
                 defer hook.Reset()

-                LogOnError("message ", err)
+                LogOnError(ctx, "message ", err)
                 Expect(hook.LastEntry().Message).Should(Equal("message test"))
             })
         })
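This spec keeps passing because the injected spec context carries no logger, so FromCtx falls back to the global logger that test.NewGlobal() hooks into. The same capture-and-assert pattern in isolation, as a sketch using logrus's test helpers (NewNullLogger instead of a global hook, to stay self-contained):

package main

import (
    "errors"

    "github.com/sirupsen/logrus/hooks/test"
)

func main() {
    // NewNullLogger returns a logger that writes nowhere, plus a hook that
    // records every entry so tests can assert on what was logged.
    logger, hook := test.NewNullLogger()

    logger.Error("message ", errors.New("test"))

    // logrus concatenates args like fmt.Sprint, hence "message test".
    if hook.LastEntry().Message != "message test" {
        panic("unexpected log message: " + hook.LastEntry().Message)
    }
}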