From 4919ffac0d3ae0691e8771d9ec8aca6741694c60 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Sat, 27 Jan 2024 19:36:32 -0500 Subject: [PATCH] fix(server): use RCode=ServFail instead of HTTP 500 for internal errors RFC 8484 Section 4.2.1: > A successful HTTP response with a 2xx status code (see > Section 6.3 of RFC7231) is used for any valid DNS response, > regardless of the DNS response code. For example, a successful 2xx > HTTP status code is used even with a DNS message whose DNS response > code indicates failure, such as SERVFAIL or NXDOMAIN. https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1 --- server/server.go | 8 ++++++-- server/server_endpoints.go | 28 ++++++++++++++++++++------ server/server_test.go | 40 +++++++++++++++++++++++++++++++++++--- 3 files changed, 65 insertions(+), 11 deletions(-) diff --git a/server/server.go b/server/server.go index 58c69e30..18771784 100644 --- a/server/server.go +++ b/server/server.go @@ -615,8 +615,12 @@ func (s *Server) OnRequest( } } -func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) { - var response *model.Response +func (s *Server) resolve(ctx context.Context, request *model.Request) (response *model.Response, rerr error) { + defer func() { + if val := recover(); val != nil { + rerr = fmt.Errorf("panic occurred: %v", val) + } + }() contextUpstreamTimeoutMultiplier := 100 timeoutDuration := time.Duration(contextUpstreamTimeoutMultiplier) * s.cfg.Upstreams.Timeout.ToDuration() diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 0151bf60..ebbe34da 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -134,7 +134,6 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *http.Request) { msg := new(dns.Msg) - if err := msg.Unpack(rawMsg); err != nil { logger().Error("can't deserialize message: ", err) http.Error(rw, err.Error(), http.StatusBadRequest) @@ -142,6 +141,25 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h return } + rw.Header().Set("content-type", dnsContentType) + + writeErr := func(err error) { + log.Log().Error(err) + + msg := new(dns.Msg) + msg.SetRcode(msg, dns.RcodeServerFailure) + + buff, err := msg.Pack() + if err != nil { + return + } + + // https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1 + rw.WriteHeader(http.StatusOK) + + _, _ = rw.Write(buff) + } + clientID := chi.URLParam(req, "clientID") if clientID == "" { clientID = extractClientIDFromHost(req.Host) @@ -151,22 +169,20 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h resResponse, err := s.resolve(ctx, r) if err != nil { - logAndResponseWithError(err, "unable to process query: ", rw) + writeErr(fmt.Errorf("unable to process query: %w", err)) return } b, err := resResponse.Res.Pack() if err != nil { - logAndResponseWithError(err, "can't serialize message: ", rw) + writeErr(fmt.Errorf("can't serialize message: %w", err)) return } - rw.Header().Set("content-type", dnsContentType) - _, err = rw.Write(b) - logAndResponseWithError(err, "can't write response: ", rw) + log.Log().Error(fmt.Errorf("can't write response: %w", err)) } func (s *Server) Query( diff --git a/server/server_test.go b/server/server_test.go index 85ea63a5..549f3025 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -529,8 +529,8 @@ var _ = Describe("Running DNS server", func() { Expect(resp).Should(HaveHTTPStatus(http.StatusUnsupportedMediaType)) }) }) - When("Internal error occurs", func() { - It("should return 'Internal server error'", func() { + When("DNS error occurs", func() { + It("should return 'ServFail'", func() { msg = util.NewMsgWithQuestion("error.", A) rawDNSMessage, err := msg.Pack() Expect(err).Should(Succeed()) @@ -540,7 +540,41 @@ var _ = Describe("Running DNS server", func() { Expect(err).Should(Succeed()) DeferCleanup(resp.Body.Close) - Expect(resp).Should(HaveHTTPStatus(http.StatusInternalServerError)) + Expect(resp).Should(HaveHTTPStatus(http.StatusOK)) + + body, err := io.ReadAll(resp.Body) + Expect(err).Should(Succeed()) + + msg := new(dns.Msg) + Expect(msg.Unpack(body)).Should(Succeed()) + Expect(msg.Rcode).Should(Equal(dns.RcodeServerFailure)) + }) + }) + When("Internal error occurs", func() { + BeforeEach(func() { + bak := sut.queryResolver + sut.queryResolver = nil // trigger a panic + DeferCleanup(func() { sut.queryResolver = bak }) + }) + + It("should return 'ServFail'", func() { + msg = util.NewMsgWithQuestion("error.", A) + rawDNSMessage, err := msg.Pack() + Expect(err).Should(Succeed()) + + resp, err = http.Post(queryURL, + "application/dns-message", bytes.NewReader(rawDNSMessage)) + Expect(err).Should(Succeed()) + DeferCleanup(resp.Body.Close) + + Expect(resp).Should(HaveHTTPStatus(http.StatusOK)) + + body, err := io.ReadAll(resp.Body) + Expect(err).Should(Succeed()) + + msg := new(dns.Msg) + Expect(msg.Unpack(body)).Should(Succeed()) + Expect(msg.Rcode).Should(Equal(dns.RcodeServerFailure)) }) }) })