refactor(log): store log in context so it's automatically propagated

Author: ThinkChaos — 2024-01-27 17:47:01 -05:00
parent d83b7432d4
commit b335887992
17 changed files with 177 additions and 102 deletions

log/context.go (new file, 46 lines added)
View File

@@ -0,0 +1,46 @@
package log

import (
	"context"

	"github.com/sirupsen/logrus"
)

type ctxKey struct{}

func NewCtx(ctx context.Context, logger *logrus.Entry) (context.Context, *logrus.Entry) {
	ctx = context.WithValue(ctx, ctxKey{}, logger)

	return ctx, entryWithCtx(ctx, logger)
}

func FromCtx(ctx context.Context) *logrus.Entry {
	logger, ok := ctx.Value(ctxKey{}).(*logrus.Entry)
	if !ok {
		// Fallback to the global logger
		return logrus.NewEntry(Log())
	}

	// Ensure `logger.Context == ctx`, not always the case since `ctx` could be a child of `logger.Context`
	return entryWithCtx(ctx, logger)
}

func entryWithCtx(ctx context.Context, logger *logrus.Entry) *logrus.Entry {
	loggerCopy := *logger
	loggerCopy.Context = ctx

	return &loggerCopy
}

func WrapCtx(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) {
	logger := FromCtx(ctx)
	logger = wrap(logger)

	return NewCtx(ctx, logger)
}

func CtxWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) {
	return WrapCtx(ctx, func(e *logrus.Entry) *logrus.Entry {
		return e.WithFields(fields)
	})
}
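For readers new to the pattern introduced here, the following is a minimal usage sketch, not part of this commit: a caller attaches request-scoped fields to the logger and stores it in the context; a callee recovers it with FromCtx (falling back to the global logger if nothing was stored). The handleQuery/resolve functions and the clientIP parameter are hypothetical illustrations only.

package example

import (
	"context"

	"github.com/0xERR0R/blocky/log"
	"github.com/sirupsen/logrus"
)

// handleQuery (hypothetical) enriches the context-scoped logger and only
// passes ctx down the call chain; the logger travels with it.
func handleQuery(ctx context.Context, clientIP string) {
	ctx, logger := log.CtxWithFields(ctx, logrus.Fields{"client_ip": clientIP})
	logger.Debug("handling query")

	resolve(ctx)
}

// resolve (hypothetical) needs no logger parameter: it recovers the enriched
// entry from ctx, so the client_ip field set above is propagated automatically.
func resolve(ctx context.Context) {
	log.FromCtx(ctx).Info("resolving")
}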

View File

@@ -174,18 +174,20 @@ func NewBlockingResolver(ctx context.Context,
 }

 func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
+	ctx, logger := r.log(ctx)
+
 	for {
 		select {
 		case em := <-r.redisClient.EnabledChannel:
 			if em != nil {
-				r.log().Debug("Received state from redis: ", em)
+				logger.Debug("Received state from redis: ", em)

 				if em.State {
 					r.internalEnableBlocking()
 				} else {
 					err := r.internalDisableBlocking(ctx, em.Duration, em.Groups)
 					if err != nil {
-						r.log().Warn("Blocking couldn't be disabled:", err)
+						logger.Warn("Blocking couldn't be disabled:", err)
 					}
 				}
 			}
@@ -394,7 +396,7 @@ func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []
 // Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
 func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, "blacklist_resolver")
+	ctx, logger := r.log(ctx)

 	groupsToCheck := r.groupsToCheckForClient(request)
 	if len(groupsToCheck) > 0 {
@@ -575,7 +577,9 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
 }

 func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
-	prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
+	ctx, logger := r.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry {
+		return log.WithPrefix(logger, "client_id_cache")
+	})

 	var result []net.IP
@@ -584,7 +588,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi
 	for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
 		resp, err := r.next.Resolve(ctx, &model.Request{
 			Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
-			Log: prefixedLog,
+			Log: logger,
 		})

 		if err == nil && resp.Res.Rcode == dns.RcodeSuccess {
@@ -598,11 +602,16 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi
 					result = append(result, v.AAAA)
 				}
 			}
-			prefixedLog.Debugf("resolved IPs '%v' for fq identifier '%s'", result, identifier)
 		}
 	}

+	if len(result) != 0 {
+		logger.WithFields(logrus.Fields{
+			"ips":       result,
+			"client_id": identifier,
+		}).Debug("resolved client IPs")
+	}
+
 	return &result, ttl
 }

View File

@@ -12,7 +12,6 @@ import (
 	"time"

 	"github.com/0xERR0R/blocky/config"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/util"
 	"github.com/hashicorp/go-multierror"
@@ -70,13 +69,15 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
 		dialer: new(net.Dialer),
 	}

+	ctx, logger := b.log(ctx)
+
 	bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams)
 	if err != nil {
 		return nil, err
 	}

 	if len(bootstraped) == 0 {
-		b.log().Info("bootstrapDns is not configured, will use system resolver")
+		logger.Info("bootstrapDns is not configured, will use system resolver")

 		return b, nil
 	}
@@ -116,10 +117,9 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model
 	}

 	// Add bootstrap prefix to all inner resolver logs
-	req := *request
-	req.Log = log.WithPrefix(req.Log, b.Type())
+	ctx, _ = b.log(ctx)

-	return b.resolver.Resolve(ctx, &req)
+	return b.resolver.Resolve(ctx, request)
 }

 func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
@@ -168,7 +168,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
 }

 func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
-	logger := b.log().WithFields(logrus.Fields{"network": network, "addr": addr})
+	ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr})

 	host, port, err := net.SplitHostPort(addr)
 	if err != nil {
@@ -234,9 +234,10 @@ func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.
 		return []net.IP{ip}, nil
 	}

+	ctx, _ = b.log(ctx)
+
 	req := model.Request{
 		Req: util.NewMsgWithQuestion(hostname, qType),
-		Log: b.log(),
 	}

 	rsp, err := b.resolver.Resolve(ctx, &req)

View File

@@ -10,7 +10,6 @@ import (
 	"github.com/0xERR0R/blocky/cache/expirationcache"
 	"github.com/0xERR0R/blocky/config"
 	"github.com/0xERR0R/blocky/evt"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/redis"
 	"github.com/0xERR0R/blocky/util"
@@ -107,7 +106,7 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
 func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
 	qType, domainName := util.ExtractCacheKey(cacheKey)
-	logger := r.log()
+	ctx, logger := r.log(ctx)

 	logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
@@ -133,11 +132,13 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string)
 }

 func (r *CachingResolver) redisSubscriber(ctx context.Context) {
+	ctx, logger := r.log(ctx)
+
 	for {
 		select {
 		case rc := <-r.redisClient.CacheChannel:
 			if rc != nil {
-				r.log().Debug("Received key from redis: ", rc.Key)
+				logger.Debug("Received key from redis: ", rc.Key)
 				ttl := r.adjustTTLs(rc.Response.Res.Answer)
 				r.putInCache(rc.Key, rc.Response, ttl, false)
 			}
@@ -158,7 +159,7 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
 // Resolve checks if the current query should use the cache and if the result is already in
 // the cache and returns it or delegates to the next resolver
 func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
-	logger := log.WithPrefix(request.Log, "caching_resolver")
+	ctx, logger := r.log(ctx)

 	if !r.IsEnabled() || !isRequestCacheable(request) {
 		logger.Debug("skip cache")
@@ -171,7 +172,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
 		cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
 		logger := logger.WithField("domain", util.Obfuscate(domain))

-		val, ttl := r.getFromCache(cacheKey)
+		val, ttl := r.getFromCache(logger, cacheKey)

 		if val != nil {
 			logger.Debug("domain is cached")
@@ -200,7 +201,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
 	return response, err
 }

-func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) {
+func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) {
 	val, ttl := r.resultCache.Get(key)
 	if val == nil {
 		return nil, 0
@@ -210,7 +211,7 @@ func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) {
 	err := res.Unpack(*val)
 	if err != nil {
-		r.log().Error("can't unpack cached entry. Cache malformed?", err)
+		logger.Error("can't unpack cached entry. Cache malformed?", err)

 		return nil, 0
 	}
@@ -317,7 +318,9 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
 	}
 }

-func (r *CachingResolver) FlushCaches(context.Context) {
-	r.log().Debug("flush caches")
+func (r *CachingResolver) FlushCaches(ctx context.Context) {
+	_, logger := r.log(ctx)
+
+	logger.Debug("flush caches")
 	r.resultCache.Clear()
 }

View File

@@ -63,7 +63,7 @@ func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Reques
 	clientNames := r.getClientNames(ctx, request)

 	request.ClientNames = clientNames
-	request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
+	ctx, request.Log = log.CtxWithFields(ctx, logrus.Fields{"client_names": strings.Join(clientNames, "; ")})

 	return r.next.Resolve(ctx, request)
 }
@@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model
 		return cpy
 	}

-	names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))
+	names := r.resolveClientNames(ctx, ip)

 	r.cache.Put(ip.String(), &names, time.Hour)
@@ -111,9 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
 }

 // tries to resolve client name from mapping, performs reverse DNS lookup otherwise
-func (r *ClientNamesResolver) resolveClientNames(
-	ctx context.Context, ip net.IP, logger *logrus.Entry,
-) (result []string) {
+func (r *ClientNamesResolver) resolveClientNames(ctx context.Context, ip net.IP) (result []string) {
+	ctx, logger := r.log(ctx)
+
 	// try client mapping first
 	result = r.getNameFromIPMapping(ip, result)
 	if len(result) > 0 {

View File

@@ -6,7 +6,6 @@ import (
 	"strings"

 	"github.com/0xERR0R/blocky/config"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/util"
@@ -83,7 +82,7 @@ func (r *ConditionalUpstreamResolver) processRequest(
 // Resolve uses the conditional resolver to resolve the query
 func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, "conditional_resolver")
+	ctx, logger := r.log(ctx)

 	if len(r.mapping) > 0 {
 		resolved, resp, err := r.processRequest(ctx, request)
@@ -101,7 +100,7 @@ func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso
 	req *model.Request,
 ) (*model.Response, error) {
 	// internal request resolution
-	logger := log.WithPrefix(req.Log, "conditional_resolver")
+	ctx, logger := r.log(ctx)

 	req.Req.Question[0].Name = dns.Fqdn(doFQ)
 	response, err := reso.Resolve(ctx, req)

View File

@@ -8,7 +8,6 @@ import (
 	"strings"

 	"github.com/0xERR0R/blocky/config"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/util"
@@ -105,28 +104,12 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
 	return nil
 }

-func (r *CustomDNSResolver) forwardResponse(
-	logger *logrus.Entry,
-	ctx context.Context,
-	request *model.Request,
-) (*model.Response, error) {
-	logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
-
-	forwardResponse, err := r.next.Resolve(ctx, request)
-	if err != nil {
-		return nil, err
-	}
-
-	return forwardResponse, nil
-}
-
 func (r *CustomDNSResolver) processRequest(
 	ctx context.Context,
+	logger *logrus.Entry,
 	request *model.Request,
 	resolvedCnames []string,
 ) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, "custom_dns_resolver")
-
 	response := new(dns.Msg)
 	response.SetReply(request.Req)
@@ -142,7 +125,7 @@ func (r *CustomDNSResolver) processRequest(
 	if found {
 		for _, entry := range entries {
-			result, err := r.processDNSEntry(ctx, request, resolvedCnames, question, entry)
+			result, err := r.processDNSEntry(ctx, logger, request, resolvedCnames, question, entry)
 			if err != nil {
 				return nil, err
 			}
@@ -169,18 +152,21 @@ func (r *CustomDNSResolver) processRequest(
 			return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
 		}

-		if i := strings.Index(domain, "."); i >= 0 {
+		if i := strings.IndexRune(domain, '.'); i >= 0 {
 			domain = domain[i+1:]
 		} else {
 			break
 		}
 	}

-	return r.forwardResponse(logger, ctx, request)
+	logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
+
+	return r.next.Resolve(ctx, request)
 }

 func (r *CustomDNSResolver) processDNSEntry(
 	ctx context.Context,
+	logger *logrus.Entry,
 	request *model.Request,
 	resolvedCnames []string,
 	question dns.Question,
@@ -192,7 +178,7 @@ func (r *CustomDNSResolver) processDNSEntry(
 	case *dns.AAAA:
 		return r.processIP(v.AAAA, question, v.Header().Ttl)
 	case *dns.CNAME:
-		return r.processCNAME(ctx, request, *v, resolvedCnames, question, v.Header().Ttl)
+		return r.processCNAME(ctx, logger, request, *v, resolvedCnames, question, v.Header().Ttl)
 	}

 	return nil, fmt.Errorf("unsupported customDNS RR type %T", entry)
@@ -200,17 +186,14 @@ func (r *CustomDNSResolver) processDNSEntry(
 // Resolve uses internal mapping to resolve the query
 func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
+	ctx, logger := r.log(ctx)
+
 	reverseResp := r.handleReverseDNS(request)
 	if reverseResp != nil {
 		return reverseResp, nil
 	}

-	resp, err := r.processRequest(ctx, request, make([]string, 0, len(r.cfg.Mapping)))
-	if err != nil {
-		return nil, err
-	}
-
-	return resp, nil
+	return r.processRequest(ctx, logger, request, make([]string, 0, len(r.cfg.Mapping)))
 }

 func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint32) (result []dns.RR, err error) {
@@ -230,6 +213,7 @@ func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint
 func (r *CustomDNSResolver) processCNAME(
 	ctx context.Context,
+	logger *logrus.Entry,
 	request *model.Request,
 	targetCname dns.CNAME,
 	resolvedCnames []string,
@@ -259,7 +243,7 @@ func (r *CustomDNSResolver) processCNAME(
 	targetRequest := newRequestWithClientID(targetWithoutDot, dns.Type(question.Qtype), clientIP, clientID)

 	// resolve the target recursively
-	targetResp, err := r.processRequest(ctx, targetRequest, cnames)
+	targetResp, err := r.processRequest(ctx, logger, targetRequest, cnames)
 	if err != nil {
 		return nil, err
 	}

View File

@@ -46,7 +46,8 @@ func NewHostsFileResolver(ctx context.Context,
 	}

 	err := cfg.Loading.StartPeriodicRefresh(ctx, r.loadSources, func(err error) {
-		r.log().WithError(err).Errorf("could not load hosts files")
+		_, logger := r.log(ctx)
+		logger.WithError(err).Errorf("could not load hosts files")
 	})
 	if err != nil {
 		return nil, err
@@ -114,6 +115,8 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request)
 		return r.next.Resolve(ctx, request)
 	}

+	ctx, logger := r.log(ctx)
+
 	reverseResp := r.handleReverseDNS(request)
 	if reverseResp != nil {
 		return reverseResp, nil
@@ -124,7 +127,7 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request)
 	response := r.resolve(request.Req, question, domain)
 	if response != nil {
-		r.log().WithFields(logrus.Fields{
+		logger.WithFields(logrus.Fields{
 			"answer": util.AnswerToString(response.Answer),
 			"domain": util.Obfuscate(domain),
 		}).Debugf("returning hosts file entry")
@@ -132,7 +135,7 @@ func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request)
 		return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil
 	}

-	r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
+	logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")

 	return r.next.Resolve(ctx, request)
 }
@@ -157,7 +160,9 @@ func (r *HostsFileResolver) loadSources(ctx context.Context) error {
 		return nil
 	}

-	r.log().Debug("loading hosts files")
+	ctx, logger := r.log(ctx)
+
+	logger.Debug("loading hosts files")

 	//nolint:ineffassign,staticcheck,wastedassign // keep `ctx :=` so if we use ctx in the future, we use the correct one
 	consumersGrp, ctx := jobgroup.WithContext(ctx)
@@ -220,7 +225,9 @@ func (r *HostsFileResolver) parseFile(
 	p := parsers.AllowErrors(parsers.HostsFile(reader), r.cfg.Loading.MaxErrorsPerSource)
 	p.OnErr(func(err error) {
-		r.log().Warnf("error parsing %s: %s, trying to continue", opener, err)
+		_, logger := r.log(ctx)
+
+		logger.Warnf("error parsing %s: %s, trying to continue", opener, err)
 	})

 	return parsers.ForEach[*HostsFileEntry](ctx, p, func(entry *HostsFileEntry) error {

View File

@@ -10,7 +10,6 @@ import (
 	"time"

 	"github.com/0xERR0R/blocky/config"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/util"
@@ -148,7 +147,7 @@ func (r *ParallelBestResolver) String() string {
 // Resolve sends the query request to multiple upstream resolvers and returns the fastest result
 func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, parallelResolverType)
+	ctx, logger := r.log(ctx)

 	allResolvers := *r.resolvers.Load()

View File

@@ -112,7 +112,7 @@ func (r *QueryLoggingResolver) doCleanUp() {
 // Resolve logs the query, duration and the result
 func (r *QueryLoggingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, queryLoggingResolverType)
+	ctx, logger := r.log(ctx)

 	start := time.Now()
@@ -170,6 +170,8 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response *
 // write entry: if log directory is configured, write to log file
 func (r *QueryLoggingResolver) writeLog(ctx context.Context) {
+	ctx, logger := r.log(ctx)
+
 	for {
 		select {
 		case logEntry := <-r.logChan:
@@ -181,8 +183,9 @@ func (r *QueryLoggingResolver) writeLog(ctx context.Context) {
 			// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
 			if len(r.logChan) > halfCap {
-				r.log().WithField("channel_len",
-					len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
+				logger.
+					WithField("channel_len", len(r.logChan)).
+					Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds())
 			}
 		case <-ctx.Done():
 			return

View File

@@ -199,8 +199,22 @@ func (t *typed) String() string {
 	return t.Type()
 }

-func (t *typed) log() *logrus.Entry {
-	return log.PrefixedLog(t.Type())
+func (t *typed) log(ctx context.Context) (context.Context, *logrus.Entry) {
+	return t.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry { return logger })
+}
+
+func (t *typed) logWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) {
+	return t.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry {
+		return logger.WithFields(fields)
+	})
+}
+
+func (t *typed) logWith(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) {
+	return log.WrapCtx(ctx, func(logger *logrus.Entry) *logrus.Entry {
+		logger = log.WithPrefix(logger, t.Type())
+
+		return wrap(logger)
+	})
 }

 // Should be embedded in a Resolver to auto-implement `config.Configurable`.
@@ -223,7 +237,7 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
 }

 type initializable interface {
-	log() *logrus.Entry
+	log(context.Context) (context.Context, *logrus.Entry)
 	setResolvers([]*upstreamResolverStatus)
 }
@@ -242,7 +256,9 @@ func initGroupResolvers[T initializable](
 	}

 	onErr := func(err error) {
-		r.log().WithError(err).Error("upstream verification error, will continue to use bootstrap DNS")
+		_, logger := r.log(ctx)
+
+		logger.WithError(err).Error("upstream verification error, will continue to use bootstrap DNS")
 	}

 	err := cfg.Init.Strategy.Do(ctx, init, onErr)

View File

@@ -6,7 +6,6 @@ import (
 	"strings"

 	"github.com/0xERR0R/blocky/config"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/util"
@@ -59,7 +58,7 @@ func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
 // Resolve uses the inner resolver to resolve the rewritten query
 func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, "rewriter_resolver")
+	ctx, logger := r.log(ctx)

 	original := request.Req

View File

@@ -8,7 +8,6 @@ import (
 	"sync/atomic"

 	"github.com/0xERR0R/blocky/config"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/util"
@@ -74,7 +73,7 @@ func (r *StrictResolver) String() string {
 // Resolve sends the query request in a strict order to the upstream resolvers
 func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, strictResolverType)
+	ctx, logger := r.log(ctx)

 	// start with first resolver
 	for _, resolver := range *r.resolvers.Load() {

View File

@@ -275,7 +275,9 @@ func NewUpstreamResolver(
 	r := newUpstreamResolverUnchecked(cfg, bootstrap)

 	onErr := func(err error) {
-		r.log().WithError(err).Warn("initial resolver test failed")
+		_, logger := r.log(ctx)
+
+		logger.WithError(err).Warn("initial resolver test failed")
 	}

 	err := cfg.Init.Strategy.Do(ctx, r.testResolve, onErr)
@@ -307,8 +309,10 @@ func (r UpstreamResolver) Upstream() config.Upstream {
 	return r.cfg.Upstream
 }

-func (r *UpstreamResolver) log() *logrus.Entry {
-	return r.typed.log().WithField("upstream", r.cfg.String())
+func (r *UpstreamResolver) log(ctx context.Context) (context.Context, *logrus.Entry) {
+	return r.logWithFields(ctx, logrus.Fields{
+		"upstream": r.cfg.String(),
+	})
 }

 // testResolve sends a test query to verify the upstream is reachable and working
@@ -322,7 +326,9 @@ func (r *UpstreamResolver) testResolve(ctx context.Context) error {
 }

 // Resolve calls external resolver
-func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
+func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
+	ctx, logger := r.log(ctx)
+
 	ips, err := r.bootstrap.UpstreamIPs(ctx, r)
 	if err != nil {
 		return nil, err
@@ -347,7 +353,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
 		}

 		resp = response
-		r.logResponse(request, response, ip, rtt)
+		r.logResponse(logger, request, response, ip, rtt)

 		return nil
 	},
@@ -358,7 +364,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
 		retry.LastErrorOnly(true),
 		retry.RetryIf(isTimeout),
 		retry.OnRetry(func(n uint, err error) {
-			r.log().WithFields(logrus.Fields{
+			logger.WithFields(logrus.Fields{
 				"upstream":    r.cfg.String(),
 				"upstream_ip": ip.String(),
 				"question":    util.QuestionToString(request.Req.Question),
@@ -374,8 +380,10 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
 	return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.cfg)}, nil
 }

-func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) {
-	r.log().WithFields(logrus.Fields{
+func (r *UpstreamResolver) logResponse(
+	logger *logrus.Entry, request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration,
+) {
+	logger.WithFields(logrus.Fields{
 		"answer":      util.AnswerToString(resp.Answer),
 		"return_code": dns.RcodeToString[resp.Rcode],
 		"upstream":    r.cfg.String(),

View File

@@ -7,7 +7,6 @@ import (
 	"strings"

 	"github.com/0xERR0R/blocky/config"
-	"github.com/0xERR0R/blocky/log"
 	"github.com/0xERR0R/blocky/model"
 	"github.com/0xERR0R/blocky/util"
 	"github.com/sirupsen/logrus"
@@ -106,9 +105,9 @@ func (r *UpstreamTreeResolver) String() string {
 }

 func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
-	logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
+	ctx, logger := r.log(ctx)

-	group := r.upstreamGroupByClient(request)
+	group := r.upstreamGroupByClient(logger, request)

 	// delegate request to group resolver
 	logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
@@ -116,7 +115,7 @@ func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Reque
 	return r.branches[group].Resolve(ctx, request)
 }

-func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
+func (r *UpstreamTreeResolver) upstreamGroupByClient(logger *logrus.Entry, request *model.Request) string {
 	groups := make([]string, 0, len(r.branches))

 	clientIP := request.ClientIP.String()
@@ -145,7 +144,7 @@ func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) str
 	if len(groups) > 0 {
 		if len(groups) > 1 {
-			request.Log.WithFields(logrus.Fields{
+			logger.WithFields(logrus.Fields{
 				"clientNames": request.ClientNames,
 				"clientIP":    clientIP,
 				"groups":      groups,

View File

@@ -563,19 +563,22 @@ func extractClientIDFromHost(hostName string) string {
 }

 func newRequest(
+	ctx context.Context,
 	clientIP net.IP, clientID string,
 	protocol model.RequestProtocol, request *dns.Msg,
-) *model.Request {
-	return &model.Request{
+) (context.Context, *model.Request) {
+	ctx, logger := log.CtxWithFields(ctx, logrus.Fields{
+		"question":  util.QuestionToString(request.Question),
+		"client_ip": clientIP,
+	})
+
+	return ctx, &model.Request{
 		ClientIP:        clientIP,
 		RequestClientID: clientID,
 		Protocol:        protocol,
 		Req:             request,
-		Log: log.Log().WithFields(logrus.Fields{
-			"question":  util.QuestionToString(request.Question),
-			"client_ip": clientIP,
-		}),
-		RequestTS: time.Now(),
+		Log:             logger,
+		RequestTS:       time.Now(),
 	}
 }
@@ -594,7 +597,7 @@ func (s *Server) OnRequest(
 		hostName = con.ConnectionState().ServerName
 	}

-	req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request)
+	ctx, req := newRequest(ctx, clientIP, extractClientIDFromHost(hostName), protocol, request)

 	response, err := s.resolve(ctx, req)
 	if err != nil {

View File

@@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
 		clientID = extractClientIDFromHost(req.Host)
 	}

-	r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
+	ctx, r := newRequest(req.Context(), util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)

-	resResponse, err := s.resolve(req.Context(), r)
+	resResponse, err := s.resolve(ctx, r)
 	if err != nil {
 		logAndResponseWithError(err, "unable to process query: ", rw)
@@ -173,7 +173,7 @@ func (s *Server) Query(
 	ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
 ) (*model.Response, error) {
 	dnsRequest := util.NewMsgWithQuestion(question, qType)
-	r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
+	ctx, r := newRequest(ctx, clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)

 	return s.resolve(ctx, r)
 }