feat: add a unique ID (`req_id`) to all logs related to a request

This commit is contained in:
ThinkChaos 2024-01-27 16:19:07 -05:00
parent 4919ffac0d
commit 0a47eaad09
2 changed files with 82 additions and 67 deletions

View File

@ -28,6 +28,7 @@ import (
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/util"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"github.com/go-chi/chi/v5"
@ -417,15 +418,11 @@ func createQueryResolver(
}
func (s *Server) registerDNSHandlers(ctx context.Context) {
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
ip, proto := resolveClientIPAndProtocol(w.RemoteAddr())
s.OnRequest(ctx, w, ip, proto, request)
}
for _, server := range s.dnsServers {
handler := server.Handler.(*dns.ServeMux)
handler.HandleFunc(".", wrappedOnRequest)
handler.HandleFunc(".", func(w dns.ResponseWriter, m *dns.Msg) {
s.OnRequest(ctx, w, m)
})
handler.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, m *dns.Msg) {
s.OnHealthCheck(ctx, w, m)
})
@ -570,11 +567,18 @@ func newRequest(
protocol model.RequestProtocol, request *dns.Msg,
) (context.Context, *model.Request) {
ctx, logger := log.CtxWithFields(ctx, logrus.Fields{
"req_id": uuid.New().String(),
"question": util.QuestionToString(request.Question),
"client_ip": clientIP,
})
return ctx, &model.Request{
logger.WithFields(logrus.Fields{
"client_request_id": request.Id,
"client_id": clientID,
"protocol": protocol,
}).Trace("new incoming request")
req := model.Request{
ClientIP: clientIP,
RequestClientID: clientID,
Protocol: protocol,
@ -582,31 +586,58 @@ func newRequest(
Log: logger,
RequestTS: time.Now(),
}
return ctx, &req
}
func newRequestFromDNS(ctx context.Context, rw dns.ResponseWriter, msg *dns.Msg) (context.Context, *model.Request) {
var (
clientIP net.IP
protocol model.RequestProtocol
)
if rw != nil {
clientIP, protocol = resolveClientIPAndProtocol(rw.RemoteAddr())
}
var clientID string
if con, ok := rw.(dns.ConnectionStater); ok && con.ConnectionState() != nil {
clientID = extractClientIDFromHost(con.ConnectionState().ServerName)
}
return newRequest(ctx, clientIP, clientID, protocol, msg)
}
func newRequestFromHTTP(ctx context.Context, req *http.Request, msg *dns.Msg) (context.Context, *model.Request) {
protocol := model.RequestProtocolTCP
clientIP := util.HTTPClientIP(req)
clientID := chi.URLParam(req, "clientID")
if clientID == "" {
clientID = extractClientIDFromHost(req.Host)
}
return newRequest(ctx, clientIP, clientID, protocol, msg)
}
// OnRequest will be executed if a new DNS request is received
func (s *Server) OnRequest(
ctx context.Context, w dns.ResponseWriter,
clientIP net.IP, protocol model.RequestProtocol,
request *dns.Msg,
) {
logger().Debug("new request")
func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, msg *dns.Msg) {
ctx, request := newRequestFromDNS(ctx, w, msg)
var hostName string
s.handleReq(ctx, request, w)
}
con, ok := w.(dns.ConnectionStater)
if ok && con.ConnectionState() != nil {
hostName = con.ConnectionState().ServerName
}
type msgWriter interface {
WriteMsg(msg *dns.Msg) error
}
ctx, req := newRequest(ctx, clientIP, extractClientIDFromHost(hostName), protocol, request)
response, err := s.resolve(ctx, req)
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
response, err := s.resolve(ctx, request)
if err != nil {
logger().Error("error on processing request:", err)
log.FromCtx(ctx).Error("error on processing request:", err)
m := new(dns.Msg)
m.SetRcode(request, dns.RcodeServerFailure)
m.SetRcode(request.Req, dns.RcodeServerFailure)
err := w.WriteMsg(m)
util.LogOnError(ctx, "can't write message: ", err)
} else {
@ -634,7 +665,7 @@ func (s *Server) resolve(ctx context.Context, request *model.Request) (response
m := new(dns.Msg)
m.SetRcode(request.Req, dns.RcodeFormatError)
request.Log.Error("query has no questions")
log.FromCtx(ctx).Error("query has no questions")
response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
default:
@ -688,10 +719,11 @@ func (s *Server) OnHealthCheck(ctx context.Context, w dns.ResponseWriter, reques
}
func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol model.RequestProtocol) {
if t, ok := addr.(*net.UDPAddr); ok {
return t.IP, model.RequestProtocolUDP
} else if t, ok := addr.(*net.TCPAddr); ok {
return t.IP, model.RequestProtocolTCP
switch a := addr.(type) {
case *net.UDPAddr:
return a.IP, model.RequestProtocolUDP
case *net.TCPAddr:
return a.IP, model.RequestProtocolTCP
}
return nil, model.RequestProtocolUDP

View File

@ -132,7 +132,7 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request
s.processDohMessage(rawMsg, rw, req)
}
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *http.Request) {
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
@ -141,57 +141,40 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
return
}
rw.Header().Set("content-type", dnsContentType)
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
writeErr := func(err error) {
log.Log().Error(err)
s.handleReq(ctx, dnsReq, httpMsgWriter{rw})
}
msg := new(dns.Msg)
msg.SetRcode(msg, dns.RcodeServerFailure)
type httpMsgWriter struct {
rw http.ResponseWriter
}
buff, err := msg.Pack()
if err != nil {
return
}
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(buff)
}
clientID := chi.URLParam(req, "clientID")
if clientID == "" {
clientID = extractClientIDFromHost(req.Host)
}
ctx, r := newRequest(req.Context(), util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
resResponse, err := s.resolve(ctx, r)
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
writeErr(fmt.Errorf("unable to process query: %w", err))
return
return err
}
b, err := resResponse.Res.Pack()
if err != nil {
writeErr(fmt.Errorf("can't serialize message: %w", err))
r.rw.Header().Set("content-type", dnsContentType)
return
}
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)
_, err = rw.Write(b)
log.Log().Error(fmt.Errorf("can't write response: %w", err))
_, err = r.rw.Write(b)
return err
}
func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType)
ctx, r := newRequest(ctx, clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
msg := util.NewMsgWithQuestion(question, qType)
clientID := extractClientIDFromHost(serverHost)
return s.resolve(ctx, r)
ctx, req := newRequest(ctx, clientIP, clientID, model.RequestProtocolTCP, msg)
return s.resolve(ctx, req)
}
func createHTTPSRouter(cfg *config.Config) *chi.Mux {