refactor(log): store log in context so it's automatically propagated

This commit is contained in:
ThinkChaos 2024-01-27 17:47:01 -05:00
parent d83b7432d4
commit b335887992
17 changed files with 177 additions and 102 deletions

46
log/context.go Normal file
View File

@ -0,0 +1,46 @@
package log
import (
"context"
"github.com/sirupsen/logrus"
)
// ctxKey is the unexported key type under which the logger entry is
// stored in a context.Context (unique by virtue of its private type).
type ctxKey struct{}
// NewCtx stores logger in ctx so it is propagated automatically, and
// returns both the derived context and an entry whose Context field
// points at that derived context.
func NewCtx(ctx context.Context, logger *logrus.Entry) (context.Context, *logrus.Entry) {
	newCtx := context.WithValue(ctx, ctxKey{}, logger)

	return newCtx, entryWithCtx(newCtx, logger)
}
// FromCtx returns the logger stored in ctx, falling back to a fresh
// entry around the global logger when none was stored.
func FromCtx(ctx context.Context) *logrus.Entry {
	if logger, ok := ctx.Value(ctxKey{}).(*logrus.Entry); ok {
		// `ctx` may be a descendant of `logger.Context`, so make sure
		// the returned entry carries `ctx` itself.
		return entryWithCtx(ctx, logger)
	}

	// Nothing stored: fall back to the global logger.
	return logrus.NewEntry(Log())
}
// entryWithCtx returns a copy of logger whose Context is set to ctx.
//
// Use logrus' own `Entry.WithContext` instead of a hand-rolled shallow
// struct copy: `loggerCopy := *logger` also duplicates the entry's
// internal `Buffer` pointer, which logrus deliberately leaves out when
// it clones an entry (sharing it between entries can corrupt output).
func entryWithCtx(ctx context.Context, logger *logrus.Entry) *logrus.Entry {
	return logger.WithContext(ctx)
}
// WrapCtx applies wrap to the logger stored in ctx (or the global
// fallback) and stores the result back, returning the new context and
// the wrapped entry.
func WrapCtx(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) {
	wrapped := wrap(FromCtx(ctx))

	return NewCtx(ctx, wrapped)
}
func CtxWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) {
return WrapCtx(ctx, func(e *logrus.Entry) *logrus.Entry {
return e.WithFields(fields)
})
}

View File

@ -174,18 +174,20 @@ func NewBlockingResolver(ctx context.Context,
}
func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
ctx, logger := r.log(ctx)
for {
select {
case em := <-r.redisClient.EnabledChannel:
if em != nil {
r.log().Debug("Received state from redis: ", em)
logger.Debug("Received state from redis: ", em)
if em.State {
r.internalEnableBlocking()
} else {
err := r.internalDisableBlocking(ctx, em.Duration, em.Groups)
if err != nil {
r.log().Warn("Blocking couldn't be disabled:", err)
logger.Warn("Blocking couldn't be disabled:", err)
}
}
}
@ -394,7 +396,7 @@ func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "blacklist_resolver")
ctx, logger := r.log(ctx)
groupsToCheck := r.groupsToCheckForClient(request)
if len(groupsToCheck) > 0 {
@ -575,7 +577,9 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
}
func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
ctx, logger := r.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry {
return log.WithPrefix(logger, "client_id_cache")
})
var result []net.IP
@ -584,7 +588,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
resp, err := r.next.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
Log: prefixedLog,
Log: logger,
})
if err == nil && resp.Res.Rcode == dns.RcodeSuccess {
@ -598,11 +602,16 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi
result = append(result, v.AAAA)
}
}
prefixedLog.Debugf("resolved IPs '%v' for fq identifier '%s'", result, identifier)
}
}
if len(result) != 0 {
logger.WithFields(logrus.Fields{
"ips": result,
"client_id": identifier,
}).Debug("resolved client IPs")
}
return &result, ttl
}

View File

@ -12,7 +12,6 @@ import (
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/hashicorp/go-multierror"
@ -70,13 +69,15 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
dialer: new(net.Dialer),
}
ctx, logger := b.log(ctx)
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams)
if err != nil {
return nil, err
}
if len(bootstraped) == 0 {
b.log().Info("bootstrapDns is not configured, will use system resolver")
logger.Info("bootstrapDns is not configured, will use system resolver")
return b, nil
}
@ -116,10 +117,9 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model
}
// Add bootstrap prefix to all inner resolver logs
req := *request
req.Log = log.WithPrefix(req.Log, b.Type())
ctx, _ = b.log(ctx)
return b.resolver.Resolve(ctx, &req)
return b.resolver.Resolve(ctx, request)
}
func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
@ -168,7 +168,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
}
func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logger := b.log().WithFields(logrus.Fields{"network": network, "addr": addr})
ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr})
host, port, err := net.SplitHostPort(addr)
if err != nil {
@ -234,9 +234,10 @@ func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.
return []net.IP{ip}, nil
}
ctx, _ = b.log(ctx)
req := model.Request{
Req: util.NewMsgWithQuestion(hostname, qType),
Log: b.log(),
}
rsp, err := b.resolver.Resolve(ctx, &req)

View File

@ -10,7 +10,6 @@ import (
"github.com/0xERR0R/blocky/cache/expirationcache"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/util"
@ -107,7 +106,7 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
logger := r.log()
ctx, logger := r.log(ctx)
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
@ -133,11 +132,13 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string)
}
func (r *CachingResolver) redisSubscriber(ctx context.Context) {
ctx, logger := r.log(ctx)
for {
select {
case rc := <-r.redisClient.CacheChannel:
if rc != nil {
r.log().Debug("Received key from redis: ", rc.Key)
logger.Debug("Received key from redis: ", rc.Key)
ttl := r.adjustTTLs(rc.Response.Res.Answer)
r.putInCache(rc.Key, rc.Response, ttl, false)
}
@ -158,7 +159,7 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
// Resolve checks if the current query should use the cache and if the result is already in
// the cache and returns it or delegates to the next resolver
func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "caching_resolver")
ctx, logger := r.log(ctx)
if !r.IsEnabled() || !isRequestCacheable(request) {
logger.Debug("skip cache")
@ -171,7 +172,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
logger := logger.WithField("domain", util.Obfuscate(domain))
val, ttl := r.getFromCache(cacheKey)
val, ttl := r.getFromCache(logger, cacheKey)
if val != nil {
logger.Debug("domain is cached")
@ -200,7 +201,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
return response, err
}
func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) {
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) {
val, ttl := r.resultCache.Get(key)
if val == nil {
return nil, 0
@ -210,7 +211,7 @@ func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) {
err := res.Unpack(*val)
if err != nil {
r.log().Error("can't unpack cached entry. Cache malformed?", err)
logger.Error("can't unpack cached entry. Cache malformed?", err)
return nil, 0
}
@ -317,7 +318,9 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
}
}
func (r *CachingResolver) FlushCaches(context.Context) {
r.log().Debug("flush caches")
func (r *CachingResolver) FlushCaches(ctx context.Context) {
_, logger := r.log(ctx)
logger.Debug("flush caches")
r.resultCache.Clear()
}

View File

@ -63,7 +63,7 @@ func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Reques
clientNames := r.getClientNames(ctx, request)
request.ClientNames = clientNames
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
ctx, request.Log = log.CtxWithFields(ctx, logrus.Fields{"client_names": strings.Join(clientNames, "; ")})
return r.next.Resolve(ctx, request)
}
@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model
return cpy
}
names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))
names := r.resolveClientNames(ctx, ip)
r.cache.Put(ip.String(), &names, time.Hour)
@ -111,9 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
}
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
func (r *ClientNamesResolver) resolveClientNames(
ctx context.Context, ip net.IP, logger *logrus.Entry,
) (result []string) {
func (r *ClientNamesResolver) resolveClientNames(ctx context.Context, ip net.IP) (result []string) {
ctx, logger := r.log(ctx)
// try client mapping first
result = r.getNameFromIPMapping(ip, result)
if len(result) > 0 {

View File

@ -6,7 +6,6 @@ import (
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -83,7 +82,7 @@ func (r *ConditionalUpstreamResolver) processRequest(
// Resolve uses the conditional resolver to resolve the query
func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "conditional_resolver")
ctx, logger := r.log(ctx)
if len(r.mapping) > 0 {
resolved, resp, err := r.processRequest(ctx, request)
@ -101,7 +100,7 @@ func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso
req *model.Request,
) (*model.Response, error) {
// internal request resolution
logger := log.WithPrefix(req.Log, "conditional_resolver")
ctx, logger := r.log(ctx)
req.Req.Question[0].Name = dns.Fqdn(doFQ)
response, err := reso.Resolve(ctx, req)

View File

@ -8,7 +8,6 @@ import (
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -105,28 +104,12 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
return nil
}
func (r *CustomDNSResolver) forwardResponse(
logger *logrus.Entry,
ctx context.Context,
request *model.Request,
) (*model.Response, error) {
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
forwardResponse, err := r.next.Resolve(ctx, request)
if err != nil {
return nil, err
}
return forwardResponse, nil
}
func (r *CustomDNSResolver) processRequest(
ctx context.Context,
logger *logrus.Entry,
request *model.Request,
resolvedCnames []string,
) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
response := new(dns.Msg)
response.SetReply(request.Req)
@ -142,7 +125,7 @@ func (r *CustomDNSResolver) processRequest(
if found {
for _, entry := range entries {
result, err := r.processDNSEntry(ctx, request, resolvedCnames, question, entry)
result, err := r.processDNSEntry(ctx, logger, request, resolvedCnames, question, entry)
if err != nil {
return nil, err
}
@ -169,18 +152,21 @@ func (r *CustomDNSResolver) processRequest(
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
}
if i := strings.Index(domain, "."); i >= 0 {
if i := strings.IndexRune(domain, '.'); i >= 0 {
domain = domain[i+1:]
} else {
break
}
}
return r.forwardResponse(logger, ctx, request)
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(ctx, request)
}
func (r *CustomDNSResolver) processDNSEntry(
ctx context.Context,
logger *logrus.Entry,
request *model.Request,
resolvedCnames []string,
question dns.Question,
@ -192,7 +178,7 @@ func (r *CustomDNSResolver) processDNSEntry(
case *dns.AAAA:
return r.processIP(v.AAAA, question, v.Header().Ttl)
case *dns.CNAME:
return r.processCNAME(ctx, request, *v, resolvedCnames, question, v.Header().Ttl)
return r.processCNAME(ctx, logger, request, *v, resolvedCnames, question, v.Header().Ttl)
}
return nil, fmt.Errorf("unsupported customDNS RR type %T", entry)
@ -200,17 +186,14 @@ func (r *CustomDNSResolver) processDNSEntry(
// Resolve uses internal mapping to resolve the query
func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
ctx, logger := r.log(ctx)
reverseResp := r.handleReverseDNS(request)
if reverseResp != nil {
return reverseResp, nil
}
resp, err := r.processRequest(ctx, request, make([]string, 0, len(r.cfg.Mapping)))
if err != nil {
return nil, err
}
return resp, nil
return r.processRequest(ctx, logger, request, make([]string, 0, len(r.cfg.Mapping)))
}
func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint32) (result []dns.RR, err error) {
@ -230,6 +213,7 @@ func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint
func (r *CustomDNSResolver) processCNAME(
ctx context.Context,
logger *logrus.Entry,
request *model.Request,
targetCname dns.CNAME,
resolvedCnames []string,
@ -259,7 +243,7 @@ func (r *CustomDNSResolver) processCNAME(
targetRequest := newRequestWithClientID(targetWithoutDot, dns.Type(question.Qtype), clientIP, clientID)
// resolve the target recursively
targetResp, err := r.processRequest(ctx, targetRequest, cnames)
targetResp, err := r.processRequest(ctx, logger, targetRequest, cnames)
if err != nil {
return nil, err
}

View File

@ -46,7 +46,8 @@ func NewHostsFileResolver(ctx context.Context,
}
err := cfg.Loading.StartPeriodicRefresh(ctx, r.loadSources, func(err error) {
r.log().WithError(err).Errorf("could not load hosts files")
_, logger := r.log(ctx)
logger.WithError(err).Errorf("could not load hosts files")
})
if err != nil {
return nil, err
@ -114,6 +115,8 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request)
return r.next.Resolve(ctx, request)
}
ctx, logger := r.log(ctx)
reverseResp := r.handleReverseDNS(request)
if reverseResp != nil {
return reverseResp, nil
@ -124,7 +127,7 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request)
response := r.resolve(request.Req, question, domain)
if response != nil {
r.log().WithFields(logrus.Fields{
logger.WithFields(logrus.Fields{
"answer": util.AnswerToString(response.Answer),
"domain": util.Obfuscate(domain),
}).Debugf("returning hosts file entry")
@ -132,7 +135,7 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request)
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil
}
r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(ctx, request)
}
@ -157,7 +160,9 @@ func (r *HostsFileResolver) loadSources(ctx context.Context) error {
return nil
}
r.log().Debug("loading hosts files")
ctx, logger := r.log(ctx)
logger.Debug("loading hosts files")
//nolint:ineffassign,staticcheck,wastedassign // keep `ctx :=` so if we use ctx in the future, we use the correct one
consumersGrp, ctx := jobgroup.WithContext(ctx)
@ -220,7 +225,9 @@ func (r *HostsFileResolver) parseFile(
p := parsers.AllowErrors(parsers.HostsFile(reader), r.cfg.Loading.MaxErrorsPerSource)
p.OnErr(func(err error) {
r.log().Warnf("error parsing %s: %s, trying to continue", opener, err)
_, logger := r.log(ctx)
logger.Warnf("error parsing %s: %s, trying to continue", opener, err)
})
return parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {

View File

@ -10,7 +10,6 @@ import (
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -148,7 +147,7 @@ func (r *ParallelBestResolver) String() string {
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, parallelResolverType)
ctx, logger := r.log(ctx)
allResolvers := *r.resolvers.Load()

View File

@ -112,7 +112,7 @@ func (r *QueryLoggingResolver) doCleanUp() {
// Resolve logs the query, duration and the result
func (r *QueryLoggingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, queryLoggingResolverType)
ctx, logger := r.log(ctx)
start := time.Now()
@ -170,6 +170,8 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response *
// write entry: if log directory is configured, write to log file
func (r *QueryLoggingResolver) writeLog(ctx context.Context) {
ctx, logger := r.log(ctx)
for {
select {
case logEntry := <-r.logChan:
@ -181,8 +183,9 @@ func (r *QueryLoggingResolver) writeLog(ctx context.Context) {
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
if len(r.logChan) > halfCap {
r.log().WithField("channel_len",
len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
logger.
WithField("channel_len", len(r.logChan)).
Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
}
case <-ctx.Done():
return

View File

@ -199,8 +199,22 @@ func (t *typed) String() string {
return t.Type()
}
func (t *typed) log() *logrus.Entry {
return log.PrefixedLog(t.Type())
func (t *typed) log(ctx context.Context) (context.Context, *logrus.Entry) {
return t.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry { return logger })
}
func (t *typed) logWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) {
return t.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry {
return logger.WithFields(fields)
})
}
func (t *typed) logWith(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) {
return log.WrapCtx(ctx, func(logger *logrus.Entry) *logrus.Entry {
logger = log.WithPrefix(logger, t.Type())
return wrap(logger)
})
}
// Should be embedded in a Resolver to auto-implement `config.Configurable`.
@ -223,7 +237,7 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
}
type initializable interface {
log() *logrus.Entry
log(context.Context) (context.Context, *logrus.Entry)
setResolvers([]*upstreamResolverStatus)
}
@ -242,7 +256,9 @@ func initGroupResolvers[T initializable](
}
onErr := func(err error) {
r.log().WithError(err).Error("upstream verification error, will continue to use bootstrap DNS")
_, logger := r.log(ctx)
logger.WithError(err).Error("upstream verification error, will continue to use bootstrap DNS")
}
err := cfg.Init.Strategy.Do(ctx, init, onErr)

View File

@ -6,7 +6,6 @@ import (
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -59,7 +58,7 @@ func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
// Resolve uses the inner resolver to resolve the rewritten query
func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "rewriter_resolver")
ctx, logger := r.log(ctx)
original := request.Req

View File

@ -8,7 +8,6 @@ import (
"sync/atomic"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -74,7 +73,7 @@ func (r *StrictResolver) String() string {
// Resolve sends the query request in a strict order to the upstream resolvers
func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, strictResolverType)
ctx, logger := r.log(ctx)
// start with first resolver
for _, resolver := range *r.resolvers.Load() {

View File

@ -275,7 +275,9 @@ func NewUpstreamResolver(
r := newUpstreamResolverUnchecked(cfg, bootstrap)
onErr := func(err error) {
r.log().WithError(err).Warn("initial resolver test failed")
_, logger := r.log(ctx)
logger.WithError(err).Warn("initial resolver test failed")
}
err := cfg.Init.Strategy.Do(ctx, r.testResolve, onErr)
@ -307,8 +309,10 @@ func (r UpstreamResolver) Upstream() config.Upstream {
return r.cfg.Upstream
}
func (r *UpstreamResolver) log() *logrus.Entry {
return r.typed.log().WithField("upstream", r.cfg.String())
func (r *UpstreamResolver) log(ctx context.Context) (context.Context, *logrus.Entry) {
return r.logWithFields(ctx, logrus.Fields{
"upstream": r.cfg.String(),
})
}
// testResolve sends a test query to verify the upstream is reachable and working
@ -322,7 +326,9 @@ func (r *UpstreamResolver) testResolve(ctx context.Context) error {
}
// Resolve calls external resolver
func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
ctx, logger := r.log(ctx)
ips, err := r.bootstrap.UpstreamIPs(ctx, r)
if err != nil {
return nil, err
@ -347,7 +353,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
}
resp = response
r.logResponse(request, response, ip, rtt)
r.logResponse(logger, request, response, ip, rtt)
return nil
},
@ -358,7 +364,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
retry.LastErrorOnly(true),
retry.RetryIf(isTimeout),
retry.OnRetry(func(n uint, err error) {
r.log().WithFields(logrus.Fields{
logger.WithFields(logrus.Fields{
"upstream": r.cfg.String(),
"upstream_ip": ip.String(),
"question": util.QuestionToString(request.Req.Question),
@ -374,8 +380,10 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.cfg)}, nil
}
func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) {
r.log().WithFields(logrus.Fields{
func (r *UpstreamResolver) logResponse(
logger *logrus.Entry, request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration,
) {
logger.WithFields(logrus.Fields{
"answer": util.AnswerToString(resp.Answer),
"return_code": dns.RcodeToString[resp.Rcode],
"upstream": r.cfg.String(),

View File

@ -7,7 +7,6 @@ import (
"strings"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
@ -106,9 +105,9 @@ func (r *UpstreamTreeResolver) String() string {
}
func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
ctx, logger := r.log(ctx)
group := r.upstreamGroupByClient(request)
group := r.upstreamGroupByClient(logger, request)
// delegate request to group resolver
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
@ -116,7 +115,7 @@ func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Reque
return r.branches[group].Resolve(ctx, request)
}
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
func (r *UpstreamTreeResolver) upstreamGroupByClient(logger *logrus.Entry, request *model.Request) string {
groups := make([]string, 0, len(r.branches))
clientIP := request.ClientIP.String()
@ -145,7 +144,7 @@ func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) str
if len(groups) > 0 {
if len(groups) > 1 {
request.Log.WithFields(logrus.Fields{
logger.WithFields(logrus.Fields{
"clientNames": request.ClientNames,
"clientIP": clientIP,
"groups": groups,

View File

@ -563,19 +563,22 @@ func extractClientIDFromHost(hostName string) string {
}
func newRequest(
ctx context.Context,
clientIP net.IP, clientID string,
protocol model.RequestProtocol, request *dns.Msg,
) *model.Request {
return &model.Request{
) (context.Context, *model.Request) {
ctx, logger := log.CtxWithFields(ctx, logrus.Fields{
"question": util.QuestionToString(request.Question),
"client_ip": clientIP,
})
return ctx, &model.Request{
ClientIP: clientIP,
RequestClientID: clientID,
Protocol: protocol,
Req: request,
Log: log.Log().WithFields(logrus.Fields{
"question": util.QuestionToString(request.Question),
"client_ip": clientIP,
}),
RequestTS: time.Now(),
Log: logger,
RequestTS: time.Now(),
}
}
@ -594,7 +597,7 @@ func (s *Server) OnRequest(
hostName = con.ConnectionState().ServerName
}
req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request)
ctx, req := newRequest(ctx, clientIP, extractClientIDFromHost(hostName), protocol, request)
response, err := s.resolve(ctx, req)
if err != nil {

View File

@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
clientID = extractClientIDFromHost(req.Host)
}
r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
ctx, r := newRequest(req.Context(), util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
resResponse, err := s.resolve(req.Context(), r)
resResponse, err := s.resolve(ctx, r)
if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw)
@ -173,7 +173,7 @@ func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType)
r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
ctx, r := newRequest(ctx, clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
return s.resolve(ctx, r)
}