From 0a47eaad09bedd448eda32be396d1caf8630ea46 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Sat, 27 Jan 2024 16:19:07 -0500 Subject: [PATCH] feat: add a unique ID (`req_id`) to all logs related to a request --- server/server.go | 90 ++++++++++++++++++++++++++------------ server/server_endpoints.go | 59 +++++++++---------------- 2 files changed, 82 insertions(+), 67 deletions(-) diff --git a/server/server.go b/server/server.go index 18771784..35c1f641 100644 --- a/server/server.go +++ b/server/server.go @@ -28,6 +28,7 @@ import ( "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/resolver" "github.com/0xERR0R/blocky/util" + "github.com/google/uuid" "github.com/hashicorp/go-multierror" "github.com/go-chi/chi/v5" @@ -417,15 +418,11 @@ func createQueryResolver( } func (s *Server) registerDNSHandlers(ctx context.Context) { - wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) { - ip, proto := resolveClientIPAndProtocol(w.RemoteAddr()) - - s.OnRequest(ctx, w, ip, proto, request) - } - for _, server := range s.dnsServers { handler := server.Handler.(*dns.ServeMux) - handler.HandleFunc(".", wrappedOnRequest) + handler.HandleFunc(".", func(w dns.ResponseWriter, m *dns.Msg) { + s.OnRequest(ctx, w, m) + }) handler.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, m *dns.Msg) { s.OnHealthCheck(ctx, w, m) }) @@ -570,11 +567,18 @@ func newRequest( protocol model.RequestProtocol, request *dns.Msg, ) (context.Context, *model.Request) { ctx, logger := log.CtxWithFields(ctx, logrus.Fields{ + "req_id": uuid.New().String(), "question": util.QuestionToString(request.Question), "client_ip": clientIP, }) - return ctx, &model.Request{ + logger.WithFields(logrus.Fields{ + "client_request_id": request.Id, + "client_id": clientID, + "protocol": protocol, + }).Trace("new incoming request") + + req := model.Request{ ClientIP: clientIP, RequestClientID: clientID, Protocol: protocol, @@ -582,31 +586,58 @@ func newRequest( Log: logger, RequestTS: time.Now(), } + + return ctx, &req +} + +func newRequestFromDNS(ctx context.Context, rw dns.ResponseWriter, msg *dns.Msg) (context.Context, *model.Request) { + var ( + clientIP net.IP + protocol model.RequestProtocol + ) + + if rw != nil { + clientIP, protocol = resolveClientIPAndProtocol(rw.RemoteAddr()) + } + + var clientID string + if con, ok := rw.(dns.ConnectionStater); ok && con.ConnectionState() != nil { + clientID = extractClientIDFromHost(con.ConnectionState().ServerName) + } + + return newRequest(ctx, clientIP, clientID, protocol, msg) +} + +func newRequestFromHTTP(ctx context.Context, req *http.Request, msg *dns.Msg) (context.Context, *model.Request) { + protocol := model.RequestProtocolTCP + clientIP := util.HTTPClientIP(req) + + clientID := chi.URLParam(req, "clientID") + if clientID == "" { + clientID = extractClientIDFromHost(req.Host) + } + + return newRequest(ctx, clientIP, clientID, protocol, msg) } // OnRequest will be executed if a new DNS request is received -func (s *Server) OnRequest( - ctx context.Context, w dns.ResponseWriter, - clientIP net.IP, protocol model.RequestProtocol, - request *dns.Msg, -) { - logger().Debug("new request") +func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, msg *dns.Msg) { + ctx, request := newRequestFromDNS(ctx, w, msg) - var hostName string + s.handleReq(ctx, request, w) +} - con, ok := w.(dns.ConnectionStater) - if ok && con.ConnectionState() != nil { - hostName = con.ConnectionState().ServerName - } +type msgWriter interface { + WriteMsg(msg *dns.Msg) error +} - ctx, req := newRequest(ctx, clientIP, extractClientIDFromHost(hostName), protocol, request) - - response, err := s.resolve(ctx, req) +func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) { + response, err := s.resolve(ctx, request) if err != nil { - logger().Error("error on processing request:", err) + log.FromCtx(ctx).Error("error on processing request:", err) m := new(dns.Msg) - m.SetRcode(request, dns.RcodeServerFailure) + m.SetRcode(request.Req, dns.RcodeServerFailure) err := w.WriteMsg(m) util.LogOnError(ctx, "can't write message: ", err) } else { @@ -634,7 +665,7 @@ func (s *Server) resolve(ctx context.Context, request *model.Request) (response m := new(dns.Msg) m.SetRcode(request.Req, dns.RcodeFormatError) - request.Log.Error("query has no questions") + log.FromCtx(ctx).Error("query has no questions") response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"} default: @@ -688,10 +719,11 @@ func (s *Server) OnHealthCheck(ctx context.Context, w dns.ResponseWriter, reques } func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol model.RequestProtocol) { - if t, ok := addr.(*net.UDPAddr); ok { - return t.IP, model.RequestProtocolUDP - } else if t, ok := addr.(*net.TCPAddr); ok { - return t.IP, model.RequestProtocolTCP + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP, model.RequestProtocolUDP + case *net.TCPAddr: + return a.IP, model.RequestProtocolTCP } return nil, model.RequestProtocolUDP diff --git a/server/server_endpoints.go b/server/server_endpoints.go index ebbe34da..54f43a12 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -132,7 +132,7 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request s.processDohMessage(rawMsg, rw, req) } -func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *http.Request) { +func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { msg := new(dns.Msg) if err := msg.Unpack(rawMsg); err != nil { logger().Error("can't deserialize message: ", err) @@ -141,57 +141,40 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h return } - rw.Header().Set("content-type", dnsContentType) + ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg) - writeErr := func(err error) { - log.Log().Error(err) + s.handleReq(ctx, dnsReq, httpMsgWriter{rw}) +} - msg := new(dns.Msg) - msg.SetRcode(msg, dns.RcodeServerFailure) +type httpMsgWriter struct { + rw http.ResponseWriter +} - buff, err := msg.Pack() - if err != nil { - return - } - - // https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1 - rw.WriteHeader(http.StatusOK) - - _, _ = rw.Write(buff) - } - - clientID := chi.URLParam(req, "clientID") - if clientID == "" { - clientID = extractClientIDFromHost(req.Host) - } - - ctx, r := newRequest(req.Context(), util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg) - - resResponse, err := s.resolve(ctx, r) +func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error { + b, err := msg.Pack() if err != nil { - writeErr(fmt.Errorf("unable to process query: %w", err)) - - return + return err } - b, err := resResponse.Res.Pack() - if err != nil { - writeErr(fmt.Errorf("can't serialize message: %w", err)) + r.rw.Header().Set("content-type", dnsContentType) - return - } + // https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1 + r.rw.WriteHeader(http.StatusOK) - _, err = rw.Write(b) - log.Log().Error(fmt.Errorf("can't write response: %w", err)) + _, err = r.rw.Write(b) + + return err } func (s *Server) Query( ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type, ) (*model.Response, error) { - dnsRequest := util.NewMsgWithQuestion(question, qType) - ctx, r := newRequest(ctx, clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest) + msg := util.NewMsgWithQuestion(question, qType) + clientID := extractClientIDFromHost(serverHost) - return s.resolve(ctx, r) + ctx, req := newRequest(ctx, clientIP, clientID, model.RequestProtocolTCP, msg) + + return s.resolve(ctx, req) } func createHTTPSRouter(cfg *config.Config) *chi.Mux {