mirror of https://github.com/0xERR0R/blocky.git
refactor(util): make `LogOnError` get the log from a `Context`
This commit is contained in:
parent
b335887992
commit
3fcf379df7
|
@ -63,7 +63,7 @@ func startServer(_ *cobra.Command, _ []string) error {
|
|||
select {
|
||||
case <-signals:
|
||||
log.Log().Infof("Terminating...")
|
||||
util.LogOnError("can't stop server: ", srv.Stop(ctx))
|
||||
util.LogOnError(ctx, "can't stop server: ", srv.Stop(ctx))
|
||||
done <- true
|
||||
|
||||
case err := <-errChan:
|
||||
|
|
|
@ -128,7 +128,7 @@ func (d *DatabaseWriter) periodicFlush(ctx context.Context) {
|
|||
case <-ticker.C:
|
||||
err := d.doDBWrite()
|
||||
|
||||
util.LogOnError("can't write entries to the database: ", err)
|
||||
util.LogOnError(ctx, "can't write entries to the database: ", err)
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
|
|
@ -125,7 +125,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string)
|
|||
return &packed, r.adjustTTLs(response.Res.Answer)
|
||||
}
|
||||
} else {
|
||||
util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
||||
util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err)
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
|
@ -140,7 +140,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) {
|
|||
if rc != nil {
|
||||
logger.Debug("Received key from redis: ", rc.Key)
|
||||
ttl := r.adjustTTLs(rc.Response.Res.Answer)
|
||||
r.putInCache(rc.Key, rc.Response, ttl, false)
|
||||
r.putInCache(ctx, rc.Key, rc.Response, ttl, false)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
|
@ -194,7 +194,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
|
|||
|
||||
if err == nil {
|
||||
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
||||
r.putInCache(cacheKey, response, cacheTTL, true)
|
||||
r.putInCache(ctx, cacheKey, response, cacheTTL, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,8 +250,8 @@ func isResponseCacheable(msg *dns.Msg) bool {
|
|||
return !msg.Truncated && !msg.CheckingDisabled
|
||||
}
|
||||
|
||||
func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
|
||||
publish bool,
|
||||
func (r *CachingResolver) putInCache(
|
||||
ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool,
|
||||
) {
|
||||
respCopy := response.Res.Copy()
|
||||
|
||||
|
@ -259,7 +259,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
|
|||
util.RemoveEdns0Record(respCopy)
|
||||
|
||||
packed, err := respCopy.Pack()
|
||||
util.LogOnError("error on packing", err)
|
||||
util.LogOnError(ctx, "error on packing", err)
|
||||
|
||||
if err == nil {
|
||||
if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) {
|
||||
|
|
|
@ -5,11 +5,13 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
||||
|
@ -161,7 +163,7 @@ func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Reque
|
|||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // abort requests to resolvers that lost the race
|
||||
|
||||
resolvers := pickRandom(allResolvers, r.resolverCount)
|
||||
resolvers := pickRandom(ctx, allResolvers, r.resolverCount)
|
||||
ch := make(chan requestResponse, len(resolvers))
|
||||
|
||||
for _, resolver := range resolvers {
|
||||
|
@ -210,7 +212,7 @@ func (r *ParallelBestResolver) retryWithDifferent(
|
|||
ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
|
||||
) (*model.Response, error) {
|
||||
// second try (if retryWithDifferentResolver == true)
|
||||
resolver := weightedRandom(*r.resolvers.Load(), resolvers)
|
||||
resolver := weightedRandom(ctx, *r.resolvers.Load(), resolvers)
|
||||
logger.Debugf("using %s as second resolver", resolver.resolver)
|
||||
|
||||
resp, err := resolver.resolve(ctx, request)
|
||||
|
@ -227,17 +229,17 @@ func (r *ParallelBestResolver) retryWithDifferent(
|
|||
}
|
||||
|
||||
// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool
|
||||
func pickRandom(resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus {
|
||||
func pickRandom(ctx context.Context, resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus {
|
||||
chosenResolvers := make([]*upstreamResolverStatus, 0, resolverCount)
|
||||
|
||||
for i := 0; i < resolverCount; i++ {
|
||||
chosenResolvers = append(chosenResolvers, weightedRandom(resolvers, chosenResolvers))
|
||||
chosenResolvers = append(chosenResolvers, weightedRandom(ctx, resolvers, chosenResolvers))
|
||||
}
|
||||
|
||||
return chosenResolvers
|
||||
}
|
||||
|
||||
func weightedRandom(in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus {
|
||||
func weightedRandom(ctx context.Context, in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus {
|
||||
const errorWindowInSec = 60
|
||||
|
||||
choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in))
|
||||
|
@ -262,7 +264,13 @@ outer:
|
|||
}
|
||||
|
||||
c, err := weightedrand.NewChooser(choices...)
|
||||
util.LogOnError("can't choose random weighted resolver: ", err)
|
||||
if err != nil {
|
||||
log.FromCtx(ctx).WithError(err).Error("can't choose random weighted resolver, falling back to uniform random")
|
||||
|
||||
val := rand.Int() //nolint:gosec // pseudo-randomness is good enough
|
||||
|
||||
return choices[val%len(choices)].Item
|
||||
}
|
||||
|
||||
return c.Pick()
|
||||
}
|
||||
|
|
|
@ -301,12 +301,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
|
||||
})
|
||||
|
||||
It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
|
||||
It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) {
|
||||
By("all resolvers have same weight for random -> equal distribution", func() {
|
||||
resolverCount := make(map[Resolver]int)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
|
||||
resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount)
|
||||
res1 := resolvers[0].resolver
|
||||
res2 := resolvers[1].resolver
|
||||
Expect(res1).ShouldNot(Equal(res2))
|
||||
|
@ -330,7 +330,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
resolverCount := make(map[*UpstreamResolver]int)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
|
||||
resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount)
|
||||
res1 := resolvers[0].resolver.(*UpstreamResolver)
|
||||
res2 := resolvers[1].resolver.(*UpstreamResolver)
|
||||
Expect(res1).ShouldNot(Equal(res2))
|
||||
|
@ -493,12 +493,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
|
||||
})
|
||||
|
||||
It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
|
||||
It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) {
|
||||
By("all resolvers have same weight for random -> equal distribution", func() {
|
||||
resolverCount := make(map[Resolver]int)
|
||||
|
||||
for i := 0; i < 2000; i++ {
|
||||
r := weightedRandom(*sut.resolvers.Load(), nil)
|
||||
r := weightedRandom(ctx, *sut.resolvers.Load(), nil)
|
||||
resolverCount[r.resolver]++
|
||||
}
|
||||
for _, v := range resolverCount {
|
||||
|
@ -517,7 +517,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
resolverCount := make(map[*UpstreamResolver]int)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
r := weightedRandom(*sut.resolvers.Load(), nil)
|
||||
r := weightedRandom(ctx, *sut.resolvers.Load(), nil)
|
||||
res := r.resolver.(*UpstreamResolver)
|
||||
|
||||
resolverCount[res]++
|
||||
|
|
|
@ -162,7 +162,7 @@ func (r *httpUpstreamClient) callExternal(
|
|||
}
|
||||
|
||||
defer func() {
|
||||
util.LogOnError("can't close response body ", httpResponse.Body.Close())
|
||||
util.LogOnError(ctx, "can't close response body ", httpResponse.Body.Close())
|
||||
}()
|
||||
|
||||
if httpResponse.StatusCode != http.StatusOK {
|
||||
|
|
|
@ -426,7 +426,9 @@ func (s *Server) registerDNSHandlers(ctx context.Context) {
|
|||
for _, server := range s.dnsServers {
|
||||
handler := server.Handler.(*dns.ServeMux)
|
||||
handler.HandleFunc(".", wrappedOnRequest)
|
||||
handler.HandleFunc("healthcheck.blocky", s.OnHealthCheck)
|
||||
handler.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
s.OnHealthCheck(ctx, w, m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -606,10 +608,10 @@ func (s *Server) OnRequest(
|
|||
m := new(dns.Msg)
|
||||
m.SetRcode(request, dns.RcodeServerFailure)
|
||||
err := w.WriteMsg(m)
|
||||
util.LogOnError("can't write message: ", err)
|
||||
util.LogOnError(ctx, "can't write message: ", err)
|
||||
} else {
|
||||
err := w.WriteMsg(response.Res)
|
||||
util.LogOnError("can't write message: ", err)
|
||||
util.LogOnError(ctx, "can't write message: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -672,13 +674,13 @@ func getMaxResponseSize(req *model.Request) int {
|
|||
}
|
||||
|
||||
// OnHealthCheck Handler for docker health check. Just returns OK code without delegating to resolver chain
|
||||
func (s *Server) OnHealthCheck(w dns.ResponseWriter, request *dns.Msg) {
|
||||
func (s *Server) OnHealthCheck(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
|
||||
resp := new(dns.Msg)
|
||||
resp.SetReply(request)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
|
||||
err := w.WriteMsg(resp)
|
||||
util.LogOnError("can't write message: ", err)
|
||||
util.LogOnError(ctx, "can't write message: ", err)
|
||||
}
|
||||
|
||||
func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol model.RequestProtocol) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -154,9 +155,9 @@ func IterateValueSorted(in map[string]int, fn func(string, int)) {
|
|||
}
|
||||
|
||||
// LogOnError logs the message only if error is not nil
|
||||
func LogOnError(message string, err error) {
|
||||
func LogOnError(ctx context.Context, message string, err error) {
|
||||
if err != nil {
|
||||
log.Log().Error(message, err)
|
||||
log.FromCtx(ctx).Error(message, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -153,11 +154,11 @@ var _ = Describe("Common function tests", func() {
|
|||
Describe("Logging functions", func() {
|
||||
When("LogOnError is called with error", func() {
|
||||
err := errors.New("test")
|
||||
It("should log", func() {
|
||||
It("should log", func(ctx context.Context) {
|
||||
hook := test.NewGlobal()
|
||||
Log().AddHook(hook)
|
||||
defer hook.Reset()
|
||||
LogOnError("message ", err)
|
||||
LogOnError(ctx, "message ", err)
|
||||
Expect(hook.LastEntry().Message).Should(Equal("message test"))
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue