fix(server): use RCode=ServFail instead of HTTP 500 for internal errors

RFC 8484 Section 4.2.1:
> A successful HTTP response with a 2xx status code (see
> Section 6.3 of RFC7231) is used for any valid DNS response,
> regardless of the DNS response code.  For example, a successful 2xx
> HTTP status code is used even with a DNS message whose DNS response
> code indicates failure, such as SERVFAIL or NXDOMAIN.
https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
This commit is contained in:
ThinkChaos 2024-01-27 19:36:32 -05:00
parent 3fcf379df7
commit 4919ffac0d
3 changed files with 65 additions and 11 deletions

View File

@ -615,8 +615,12 @@ func (s *Server) OnRequest(
}
}
func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
var response *model.Response
func (s *Server) resolve(ctx context.Context, request *model.Request) (response *model.Response, rerr error) {
defer func() {
if val := recover(); val != nil {
rerr = fmt.Errorf("panic occurred: %v", val)
}
}()
contextUpstreamTimeoutMultiplier := 100
timeoutDuration := time.Duration(contextUpstreamTimeoutMultiplier) * s.cfg.Upstreams.Timeout.ToDuration()

View File

@ -134,7 +134,6 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
http.Error(rw, err.Error(), http.StatusBadRequest)
@ -142,6 +141,25 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
return
}
rw.Header().Set("content-type", dnsContentType)
writeErr := func(err error) {
log.Log().Error(err)
msg := new(dns.Msg)
msg.SetRcode(msg, dns.RcodeServerFailure)
buff, err := msg.Pack()
if err != nil {
return
}
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(buff)
}
clientID := chi.URLParam(req, "clientID")
if clientID == "" {
clientID = extractClientIDFromHost(req.Host)
@ -151,22 +169,20 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
resResponse, err := s.resolve(ctx, r)
if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw)
writeErr(fmt.Errorf("unable to process query: %w", err))
return
}
b, err := resResponse.Res.Pack()
if err != nil {
logAndResponseWithError(err, "can't serialize message: ", rw)
writeErr(fmt.Errorf("can't serialize message: %w", err))
return
}
rw.Header().Set("content-type", dnsContentType)
_, err = rw.Write(b)
logAndResponseWithError(err, "can't write response: ", rw)
log.Log().Error(fmt.Errorf("can't write response: %w", err))
}
func (s *Server) Query(

View File

@ -529,8 +529,8 @@ var _ = Describe("Running DNS server", func() {
Expect(resp).Should(HaveHTTPStatus(http.StatusUnsupportedMediaType))
})
})
When("Internal error occurs", func() {
It("should return 'Internal server error'", func() {
When("DNS error occurs", func() {
It("should return 'ServFail'", func() {
msg = util.NewMsgWithQuestion("error.", A)
rawDNSMessage, err := msg.Pack()
Expect(err).Should(Succeed())
@ -540,7 +540,41 @@ var _ = Describe("Running DNS server", func() {
Expect(err).Should(Succeed())
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusInternalServerError))
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
body, err := io.ReadAll(resp.Body)
Expect(err).Should(Succeed())
msg := new(dns.Msg)
Expect(msg.Unpack(body)).Should(Succeed())
Expect(msg.Rcode).Should(Equal(dns.RcodeServerFailure))
})
})
When("Internal error occurs", func() {
BeforeEach(func() {
bak := sut.queryResolver
sut.queryResolver = nil // trigger a panic
DeferCleanup(func() { sut.queryResolver = bak })
})
It("should return 'ServFail'", func() {
msg = util.NewMsgWithQuestion("error.", A)
rawDNSMessage, err := msg.Pack()
Expect(err).Should(Succeed())
resp, err = http.Post(queryURL,
"application/dns-message", bytes.NewReader(rawDNSMessage))
Expect(err).Should(Succeed())
DeferCleanup(resp.Body.Close)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
body, err := io.ReadAll(resp.Body)
Expect(err).Should(Succeed())
msg := new(dns.Msg)
Expect(msg.Unpack(body)).Should(Succeed())
Expect(msg.Rcode).Should(Equal(dns.RcodeServerFailure))
})
})
})