From b335887992d36d491db6619e22bf7296312f8401 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Sat, 27 Jan 2024 17:47:01 -0500 Subject: [PATCH] refactor(log): store log in context so it's automatically propagated --- log/context.go | 46 +++++++++++++++++++++++ resolver/blocking_resolver.go | 23 ++++++++---- resolver/bootstrap.go | 15 ++++---- resolver/caching_resolver.go | 21 ++++++----- resolver/client_names_resolver.go | 10 ++--- resolver/conditional_upstream_resolver.go | 5 +-- resolver/custom_dns_resolver.go | 42 +++++++-------------- resolver/hosts_file_resolver.go | 17 ++++++--- resolver/parallel_best_resolver.go | 3 +- resolver/query_logging_resolver.go | 9 +++-- resolver/resolver.go | 24 ++++++++++-- resolver/rewriter_resolver.go | 3 +- resolver/strict_resolver.go | 3 +- resolver/upstream_resolver.go | 24 ++++++++---- resolver/upstream_tree_resolver.go | 9 ++--- server/server.go | 19 ++++++---- server/server_endpoints.go | 6 +-- 17 files changed, 177 insertions(+), 102 deletions(-) create mode 100644 log/context.go diff --git a/log/context.go b/log/context.go new file mode 100644 index 00000000..ef504698 --- /dev/null +++ b/log/context.go @@ -0,0 +1,46 @@ +package log + +import ( + "context" + + "github.com/sirupsen/logrus" +) + +type ctxKey struct{} + +func NewCtx(ctx context.Context, logger *logrus.Entry) (context.Context, *logrus.Entry) { + ctx = context.WithValue(ctx, ctxKey{}, logger) + + return ctx, entryWithCtx(ctx, logger) +} + +func FromCtx(ctx context.Context) *logrus.Entry { + logger, ok := ctx.Value(ctxKey{}).(*logrus.Entry) + if !ok { + // Fallback to the global logger + return logrus.NewEntry(Log()) + } + + // Ensure `logger.Context == ctx`, not always the case since `ctx` could be a child of `logger.Context` + return entryWithCtx(ctx, logger) +} + +func entryWithCtx(ctx context.Context, logger *logrus.Entry) *logrus.Entry { + loggerCopy := *logger + loggerCopy.Context = ctx + + return &loggerCopy +} + +func WrapCtx(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) { + logger := FromCtx(ctx) + logger = wrap(logger) + + return NewCtx(ctx, logger) +} + +func CtxWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) { + return WrapCtx(ctx, func(e *logrus.Entry) *logrus.Entry { + return e.WithFields(fields) + }) +} diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 662f3514..7e250efe 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -174,18 +174,20 @@ func NewBlockingResolver(ctx context.Context, } func (r *BlockingResolver) redisSubscriber(ctx context.Context) { + ctx, logger := r.log(ctx) + for { select { case em := <-r.redisClient.EnabledChannel: if em != nil { - r.log().Debug("Received state from redis: ", em) + logger.Debug("Received state from redis: ", em) if em.State { r.internalEnableBlocking() } else { err := r.internalDisableBlocking(ctx, em.Duration, em.Groups) if err != nil { - r.log().Warn("Blocking couldn't be disabled:", err) + logger.Warn("Blocking couldn't be disabled:", err) } } } @@ -394,7 +396,7 @@ func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck [] // Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, "blacklist_resolver") + ctx, logger := r.log(ctx) groupsToCheck := r.groupsToCheckForClient(request) if len(groupsToCheck) > 0 { @@ -575,7 +577,9 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) { } func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) { - prefixedLog := log.WithPrefix(r.log(), "client_id_cache") + ctx, logger := r.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry { + return log.WithPrefix(logger, "client_id_cache") + }) var result []net.IP @@ -584,7 +588,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} { resp, err := r.next.Resolve(ctx, &model.Request{ Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)), - Log: prefixedLog, + Log: logger, }) if err == nil && resp.Res.Rcode == dns.RcodeSuccess { @@ -598,11 +602,16 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi result = append(result, v.AAAA) } } - - prefixedLog.Debugf("resolved IPs '%v' for fq identifier '%s'", result, identifier) } } + if len(result) != 0 { + logger.WithFields(logrus.Fields{ + "ips": result, + "client_id": identifier, + }).Debug("resolved client IPs") + } + return &result, ttl } diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index 0e6ba4b1..ba88e067 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -12,7 +12,6 @@ import ( "time" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/hashicorp/go-multierror" @@ -70,13 +69,15 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er dialer: new(net.Dialer), } + ctx, logger := b.log(ctx) + bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams) if err != nil { return nil, err } if len(bootstraped) == 0 { - b.log().Info("bootstrapDns is not configured, will use system resolver") + logger.Info("bootstrapDns is not configured, will use system resolver") return b, nil } @@ -116,10 +117,9 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model } // Add bootstrap prefix to all inner resolver logs - req := *request - req.Log = log.WithPrefix(req.Log, b.Type()) + ctx, _ = b.log(ctx) - return b.resolver.Resolve(ctx, &req) + return b.resolver.Resolve(ctx, request) } func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) { @@ -168,7 +168,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport { } func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { - logger := b.log().WithFields(logrus.Fields{"network": network, "addr": addr}) + ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr}) host, port, err := net.SplitHostPort(addr) if err != nil { @@ -234,9 +234,10 @@ func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns. return []net.IP{ip}, nil } + ctx, _ = b.log(ctx) + req := model.Request{ Req: util.NewMsgWithQuestion(hostname, qType), - Log: b.log(), } rsp, err := b.resolver.Resolve(ctx, &req) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index ed056c17..51579890 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -10,7 +10,6 @@ import ( "github.com/0xERR0R/blocky/cache/expirationcache" "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/evt" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/util" @@ -107,7 +106,7 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) { qType, domainName := util.ExtractCacheKey(cacheKey) - logger := r.log() + ctx, logger := r.log(ctx) logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) @@ -133,11 +132,13 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) } func (r *CachingResolver) redisSubscriber(ctx context.Context) { + ctx, logger := r.log(ctx) + for { select { case rc := <-r.redisClient.CacheChannel: if rc != nil { - r.log().Debug("Received key from redis: ", rc.Key) + logger.Debug("Received key from redis: ", rc.Key) ttl := r.adjustTTLs(rc.Response.Res.Answer) r.putInCache(rc.Key, rc.Response, ttl, false) } @@ -158,7 +159,7 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) { // Resolve checks if the current query should use the cache and if the result is already in // the cache and returns it or delegates to the next resolver func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) { - logger := log.WithPrefix(request.Log, "caching_resolver") + ctx, logger := r.log(ctx) if !r.IsEnabled() || !isRequestCacheable(request) { logger.Debug("skip cache") @@ -171,7 +172,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain) logger := logger.WithField("domain", util.Obfuscate(domain)) - val, ttl := r.getFromCache(cacheKey) + val, ttl := r.getFromCache(logger, cacheKey) if val != nil { logger.Debug("domain is cached") @@ -200,7 +201,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( return response, err } -func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) { +func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) { val, ttl := r.resultCache.Get(key) if val == nil { return nil, 0 @@ -210,7 +211,7 @@ func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) { err := res.Unpack(*val) if err != nil { - r.log().Error("can't unpack cached entry. Cache malformed?", err) + logger.Error("can't unpack cached entry. Cache malformed?", err) return nil, 0 } @@ -317,7 +318,9 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) } } -func (r *CachingResolver) FlushCaches(context.Context) { - r.log().Debug("flush caches") +func (r *CachingResolver) FlushCaches(ctx context.Context) { + _, logger := r.log(ctx) + + logger.Debug("flush caches") r.resultCache.Clear() } diff --git a/resolver/client_names_resolver.go b/resolver/client_names_resolver.go index d022ea08..60aecd68 100644 --- a/resolver/client_names_resolver.go +++ b/resolver/client_names_resolver.go @@ -63,7 +63,7 @@ func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Reques clientNames := r.getClientNames(ctx, request) request.ClientNames = clientNames - request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; ")) + ctx, request.Log = log.CtxWithFields(ctx, logrus.Fields{"client_names": strings.Join(clientNames, "; ")}) return r.next.Resolve(ctx, request) } @@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model return cpy } - names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver")) + names := r.resolveClientNames(ctx, ip) r.cache.Put(ip.String(), &names, time.Hour) @@ -111,9 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam } // tries to resolve client name from mapping, performs reverse DNS lookup otherwise -func (r *ClientNamesResolver) resolveClientNames( - ctx context.Context, ip net.IP, logger *logrus.Entry, -) (result []string) { +func (r *ClientNamesResolver) resolveClientNames(ctx context.Context, ip net.IP) (result []string) { + ctx, logger := r.log(ctx) + // try client mapping first result = r.getNameFromIPMapping(ip, result) if len(result) > 0 { diff --git a/resolver/conditional_upstream_resolver.go b/resolver/conditional_upstream_resolver.go index b51d6d98..d69e444e 100644 --- a/resolver/conditional_upstream_resolver.go +++ b/resolver/conditional_upstream_resolver.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -83,7 +82,7 @@ func (r *ConditionalUpstreamResolver) processRequest( // Resolve uses the conditional resolver to resolve the query func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, "conditional_resolver") + ctx, logger := r.log(ctx) if len(r.mapping) > 0 { resolved, resp, err := r.processRequest(ctx, request) @@ -101,7 +100,7 @@ func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso req *model.Request, ) (*model.Response, error) { // internal request resolution - logger := log.WithPrefix(req.Log, "conditional_resolver") + ctx, logger := r.log(ctx) req.Req.Question[0].Name = dns.Fqdn(doFQ) response, err := reso.Resolve(ctx, req) diff --git a/resolver/custom_dns_resolver.go b/resolver/custom_dns_resolver.go index 602a24fd..bf160a29 100644 --- a/resolver/custom_dns_resolver.go +++ b/resolver/custom_dns_resolver.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -105,28 +104,12 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp return nil } -func (r *CustomDNSResolver) forwardResponse( - logger *logrus.Entry, - ctx context.Context, - request *model.Request, -) (*model.Response, error) { - logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") - - forwardResponse, err := r.next.Resolve(ctx, request) - if err != nil { - return nil, err - } - - return forwardResponse, nil -} - func (r *CustomDNSResolver) processRequest( ctx context.Context, + logger *logrus.Entry, request *model.Request, resolvedCnames []string, ) (*model.Response, error) { - logger := log.WithPrefix(request.Log, "custom_dns_resolver") - response := new(dns.Msg) response.SetReply(request.Req) @@ -142,7 +125,7 @@ func (r *CustomDNSResolver) processRequest( if found { for _, entry := range entries { - result, err := r.processDNSEntry(ctx, request, resolvedCnames, question, entry) + result, err := r.processDNSEntry(ctx, logger, request, resolvedCnames, question, entry) if err != nil { return nil, err } @@ -169,18 +152,21 @@ func (r *CustomDNSResolver) processRequest( return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil } - if i := strings.Index(domain, "."); i >= 0 { + if i := strings.IndexRune(domain, '.'); i >= 0 { domain = domain[i+1:] } else { break } } - return r.forwardResponse(logger, ctx, request) + logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") + + return r.next.Resolve(ctx, request) } func (r *CustomDNSResolver) processDNSEntry( ctx context.Context, + logger *logrus.Entry, request *model.Request, resolvedCnames []string, question dns.Question, @@ -192,7 +178,7 @@ func (r *CustomDNSResolver) processDNSEntry( case *dns.AAAA: return r.processIP(v.AAAA, question, v.Header().Ttl) case *dns.CNAME: - return r.processCNAME(ctx, request, *v, resolvedCnames, question, v.Header().Ttl) + return r.processCNAME(ctx, logger, request, *v, resolvedCnames, question, v.Header().Ttl) } return nil, fmt.Errorf("unsupported customDNS RR type %T", entry) @@ -200,17 +186,14 @@ func (r *CustomDNSResolver) processDNSEntry( // Resolve uses internal mapping to resolve the query func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { + ctx, logger := r.log(ctx) + reverseResp := r.handleReverseDNS(request) if reverseResp != nil { return reverseResp, nil } - resp, err := r.processRequest(ctx, request, make([]string, 0, len(r.cfg.Mapping))) - if err != nil { - return nil, err - } - - return resp, nil + return r.processRequest(ctx, logger, request, make([]string, 0, len(r.cfg.Mapping))) } func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint32) (result []dns.RR, err error) { @@ -230,6 +213,7 @@ func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint func (r *CustomDNSResolver) processCNAME( ctx context.Context, + logger *logrus.Entry, request *model.Request, targetCname dns.CNAME, resolvedCnames []string, @@ -259,7 +243,7 @@ func (r *CustomDNSResolver) processCNAME( targetRequest := newRequestWithClientID(targetWithoutDot, dns.Type(question.Qtype), clientIP, clientID) // resolve the target recursively - targetResp, err := r.processRequest(ctx, targetRequest, cnames) + targetResp, err := r.processRequest(ctx, logger, targetRequest, cnames) if err != nil { return nil, err } diff --git a/resolver/hosts_file_resolver.go b/resolver/hosts_file_resolver.go index 1dec7eb0..d2c7f1f0 100644 --- a/resolver/hosts_file_resolver.go +++ b/resolver/hosts_file_resolver.go @@ -46,7 +46,8 @@ func NewHostsFileResolver(ctx context.Context, } err := cfg.Loading.StartPeriodicRefresh(ctx, r.loadSources, func(err error) { - r.log().WithError(err).Errorf("could not load hosts files") + _, logger := r.log(ctx) + logger.WithError(err).Errorf("could not load hosts files") }) if err != nil { return nil, err @@ -114,6 +115,8 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request) return r.next.Resolve(ctx, request) } + ctx, logger := r.log(ctx) + reverseResp := r.handleReverseDNS(request) if reverseResp != nil { return reverseResp, nil @@ -124,7 +127,7 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request) response := r.resolve(request.Req, question, domain) if response != nil { - r.log().WithFields(logrus.Fields{ + logger.WithFields(logrus.Fields{ "answer": util.AnswerToString(response.Answer), "domain": util.Obfuscate(domain), }).Debugf("returning hosts file entry") @@ -132,7 +135,7 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request) return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil } - r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver") + logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") return r.next.Resolve(ctx, request) } @@ -157,7 +160,9 @@ func (r *HostsFileResolver) loadSources(ctx context.Context) error { return nil } - r.log().Debug("loading hosts files") + ctx, logger := r.log(ctx) + + logger.Debug("loading hosts files") //nolint:ineffassign,staticcheck,wastedassign // keep `ctx :=` so if we use ctx in the future, we use the correct one consumersGrp, ctx := jobgroup.WithContext(ctx) @@ -220,7 +225,9 @@ func (r *HostsFileResolver) parseFile( p := parsers.AllowErrors(parsers.HostsFile(reader), r.cfg.Loading.MaxErrorsPerSource) p.OnErr(func(err error) { - r.log().Warnf("error parsing %s: %s, trying to continue", opener, err) + _, logger := r.log(ctx) + + logger.Warnf("error parsing %s: %s, trying to continue", opener, err) }) return parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error { diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go index 31c9c399..e8dea8ed 100644 --- a/resolver/parallel_best_resolver.go +++ b/resolver/parallel_best_resolver.go @@ -10,7 +10,6 @@ import ( "time" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -148,7 +147,7 @@ func (r *ParallelBestResolver) String() string { // Resolve sends the query request to multiple upstream resolvers and returns the fastest result func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, parallelResolverType) + ctx, logger := r.log(ctx) allResolvers := *r.resolvers.Load() diff --git a/resolver/query_logging_resolver.go b/resolver/query_logging_resolver.go index 3a229df7..ea549327 100644 --- a/resolver/query_logging_resolver.go +++ b/resolver/query_logging_resolver.go @@ -112,7 +112,7 @@ func (r *QueryLoggingResolver) doCleanUp() { // Resolve logs the query, duration and the result func (r *QueryLoggingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, queryLoggingResolverType) + ctx, logger := r.log(ctx) start := time.Now() @@ -170,6 +170,8 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response * // write entry: if log directory is configured, write to log file func (r *QueryLoggingResolver) writeLog(ctx context.Context) { + ctx, logger := r.log(ctx) + for { select { case logEntry := <-r.logChan: @@ -181,8 +183,9 @@ func (r *QueryLoggingResolver) writeLog(ctx context.Context) { // if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.) if len(r.logChan) > halfCap { - r.log().WithField("channel_len", - len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds()) + logger. + WithField("channel_len", len(r.logChan)). + Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds()) } case <-ctx.Done(): return diff --git a/resolver/resolver.go b/resolver/resolver.go index 355b76eb..1fe5f8c9 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -199,8 +199,22 @@ func (t *typed) String() string { return t.Type() } -func (t *typed) log() *logrus.Entry { - return log.PrefixedLog(t.Type()) +func (t *typed) log(ctx context.Context) (context.Context, *logrus.Entry) { + return t.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry { return logger }) +} + +func (t *typed) logWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) { + return t.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry { + return logger.WithFields(fields) + }) +} + +func (t *typed) logWith(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) { + return log.WrapCtx(ctx, func(logger *logrus.Entry) *logrus.Entry { + logger = log.WithPrefix(logger, t.Type()) + + return wrap(logger) + }) } // Should be embedded in a Resolver to auto-implement `config.Configurable`. @@ -223,7 +237,7 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) { } type initializable interface { - log() *logrus.Entry + log(context.Context) (context.Context, *logrus.Entry) setResolvers([]*upstreamResolverStatus) } @@ -242,7 +256,9 @@ func initGroupResolvers[T initializable]( } onErr := func(err error) { - r.log().WithError(err).Error("upstream verification error, will continue to use bootstrap DNS") + _, logger := r.log(ctx) + + logger.WithError(err).Error("upstream verification error, will continue to use bootstrap DNS") } err := cfg.Init.Strategy.Do(ctx, init, onErr) diff --git a/resolver/rewriter_resolver.go b/resolver/rewriter_resolver.go index bf5304fe..f5171585 100644 --- a/resolver/rewriter_resolver.go +++ b/resolver/rewriter_resolver.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -59,7 +58,7 @@ func (r *RewriterResolver) LogConfig(logger *logrus.Entry) { // Resolve uses the inner resolver to resolve the rewritten query func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, "rewriter_resolver") + ctx, logger := r.log(ctx) original := request.Req diff --git a/resolver/strict_resolver.go b/resolver/strict_resolver.go index 0c80da3f..2c6d3a57 100644 --- a/resolver/strict_resolver.go +++ b/resolver/strict_resolver.go @@ -8,7 +8,6 @@ import ( "sync/atomic" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" @@ -74,7 +73,7 @@ func (r *StrictResolver) String() string { // Resolve sends the query request in a strict order to the upstream resolvers func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, strictResolverType) + ctx, logger := r.log(ctx) // start with first resolver for _, resolver := range *r.resolvers.Load() { diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go index 2a194061..f4759ae0 100644 --- a/resolver/upstream_resolver.go +++ b/resolver/upstream_resolver.go @@ -275,7 +275,9 @@ func NewUpstreamResolver( r := newUpstreamResolverUnchecked(cfg, bootstrap) onErr := func(err error) { - r.log().WithError(err).Warn("initial resolver test failed") + _, logger := r.log(ctx) + + logger.WithError(err).Warn("initial resolver test failed") } err := cfg.Init.Strategy.Do(ctx, r.testResolve, onErr) @@ -307,8 +309,10 @@ func (r UpstreamResolver) Upstream() config.Upstream { return r.cfg.Upstream } -func (r *UpstreamResolver) log() *logrus.Entry { - return r.typed.log().WithField("upstream", r.cfg.String()) +func (r *UpstreamResolver) log(ctx context.Context) (context.Context, *logrus.Entry) { + return r.logWithFields(ctx, logrus.Fields{ + "upstream": r.cfg.String(), + }) } // testResolve sends a test query to verify the upstream is reachable and working @@ -322,7 +326,9 @@ func (r *UpstreamResolver) testResolve(ctx context.Context) error { } // Resolve calls external resolver -func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { +func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) { + ctx, logger := r.log(ctx) + ips, err := r.bootstrap.UpstreamIPs(ctx, r) if err != nil { return nil, err @@ -347,7 +353,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) } resp = response - r.logResponse(request, response, ip, rtt) + r.logResponse(logger, request, response, ip, rtt) return nil }, @@ -358,7 +364,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) retry.LastErrorOnly(true), retry.RetryIf(isTimeout), retry.OnRetry(func(n uint, err error) { - r.log().WithFields(logrus.Fields{ + logger.WithFields(logrus.Fields{ "upstream": r.cfg.String(), "upstream_ip": ip.String(), "question": util.QuestionToString(request.Req.Question), @@ -374,8 +380,10 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.cfg)}, nil } -func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) { - r.log().WithFields(logrus.Fields{ +func (r *UpstreamResolver) logResponse( + logger *logrus.Entry, request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration, +) { + logger.WithFields(logrus.Fields{ "answer": util.AnswerToString(resp.Answer), "return_code": dns.RcodeToString[resp.Rcode], "upstream": r.cfg.String(), diff --git a/resolver/upstream_tree_resolver.go b/resolver/upstream_tree_resolver.go index 92939586..f7245484 100644 --- a/resolver/upstream_tree_resolver.go +++ b/resolver/upstream_tree_resolver.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/0xERR0R/blocky/config" - "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/sirupsen/logrus" @@ -106,9 +105,9 @@ func (r *UpstreamTreeResolver) String() string { } func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - logger := log.WithPrefix(request.Log, upstreamTreeResolverType) + ctx, logger := r.log(ctx) - group := r.upstreamGroupByClient(request) + group := r.upstreamGroupByClient(logger, request) // delegate request to group resolver logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver") @@ -116,7 +115,7 @@ func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Reque return r.branches[group].Resolve(ctx, request) } -func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string { +func (r *UpstreamTreeResolver) upstreamGroupByClient(logger *logrus.Entry, request *model.Request) string { groups := make([]string, 0, len(r.branches)) clientIP := request.ClientIP.String() @@ -145,7 +144,7 @@ func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) str if len(groups) > 0 { if len(groups) > 1 { - request.Log.WithFields(logrus.Fields{ + logger.WithFields(logrus.Fields{ "clientNames": request.ClientNames, "clientIP": clientIP, "groups": groups, diff --git a/server/server.go b/server/server.go index 3780822b..fdccec89 100644 --- a/server/server.go +++ b/server/server.go @@ -563,19 +563,22 @@ func extractClientIDFromHost(hostName string) string { } func newRequest( + ctx context.Context, clientIP net.IP, clientID string, protocol model.RequestProtocol, request *dns.Msg, -) *model.Request { - return &model.Request{ +) (context.Context, *model.Request) { + ctx, logger := log.CtxWithFields(ctx, logrus.Fields{ + "question": util.QuestionToString(request.Question), + "client_ip": clientIP, + }) + + return ctx, &model.Request{ ClientIP: clientIP, RequestClientID: clientID, Protocol: protocol, Req: request, - Log: log.Log().WithFields(logrus.Fields{ - "question": util.QuestionToString(request.Question), - "client_ip": clientIP, - }), - RequestTS: time.Now(), + Log: logger, + RequestTS: time.Now(), } } @@ -594,7 +597,7 @@ func (s *Server) OnRequest( hostName = con.ConnectionState().ServerName } - req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request) + ctx, req := newRequest(ctx, clientIP, extractClientIDFromHost(hostName), protocol, request) response, err := s.resolve(ctx, req) if err != nil { diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 55cc06f9..0151bf60 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h clientID = extractClientIDFromHost(req.Host) } - r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg) + ctx, r := newRequest(req.Context(), util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg) - resResponse, err := s.resolve(req.Context(), r) + resResponse, err := s.resolve(ctx, r) if err != nil { logAndResponseWithError(err, "unable to process query: ", rw) @@ -173,7 +173,7 @@ func (s *Server) Query( ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type, ) (*model.Response, error) { dnsRequest := util.NewMsgWithQuestion(question, qType) - r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest) + ctx, r := newRequest(ctx, clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest) return s.resolve(ctx, r) }