set content-type header for HTTP endpoints (#581)

This commit is contained in:
Dimitri Herzog 2022-06-29 22:36:54 +02:00 committed by GitHub
parent 3b620102a7
commit a903565cb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 62 additions and 32 deletions

View File

@ -12,6 +12,11 @@ import (
"github.com/go-chi/chi/v5"
)
const (
contentTypeHeader = "content-type"
jsonContentType = "application/json"
)
// BlockingControl interface to control the blocking status
type BlockingControl interface {
EnableBlocking()
@ -57,7 +62,8 @@ func registerListRefreshEndpoints(router chi.Router, refresher ListRefresher) {
// @Tags lists
// @Success 200 "Lists were reloaded"
// @Router /lists/refresh [post]
func (l *ListRefreshEndpoint) apiListRefresh(_ http.ResponseWriter, _ *http.Request) {
func (l *ListRefreshEndpoint) apiListRefresh(rw http.ResponseWriter, _ *http.Request) {
rw.Header().Set(contentTypeHeader, jsonContentType)
l.refresher.RefreshLists()
}
@ -75,10 +81,12 @@ func registerBlockingEndpoints(router chi.Router, control BlockingControl) {
// @Tags blocking
// @Success 200 "Blocking is enabled"
// @Router /blocking/enable [get]
func (s *BlockingEndpoint) apiBlockingEnable(_ http.ResponseWriter, _ *http.Request) {
func (s *BlockingEndpoint) apiBlockingEnable(rw http.ResponseWriter, _ *http.Request) {
log.Log().Info("enabling blocking...")
s.control.EnableBlocking()
rw.Header().Set(contentTypeHeader, jsonContentType)
}
// apiBlockingDisable is the http endpoint to disable the blocking status
@ -98,6 +106,8 @@ func (s *BlockingEndpoint) apiBlockingDisable(rw http.ResponseWriter, req *http.
err error
)
rw.Header().Set(contentTypeHeader, jsonContentType)
// parse duration from query parameter
durationParam := req.URL.Query().Get("duration")
if len(durationParam) > 0 {
@ -132,6 +142,8 @@ func (s *BlockingEndpoint) apiBlockingDisable(rw http.ResponseWriter, req *http.
func (s *BlockingEndpoint) apiBlockingStatus(rw http.ResponseWriter, _ *http.Request) {
status := s.control.BlockingStatus()
rw.Header().Set(contentTypeHeader, jsonContentType)
response, err := json.Marshal(status)
util.LogOnError("unable to marshal response ", err)

View File

@ -51,9 +51,9 @@ var _ = Describe("API tests", func() {
r := &ListRefreshMock{}
sut := &ListRefreshEndpoint{refresher: r}
It("should trigger the list refresh", func() {
httpCode, _ := DoGetRequest("/api/lists/refresh", sut.apiListRefresh)
Expect(httpCode).Should(Equal(http.StatusOK))
Expect(r.refreshTriggered).Should(BeTrue())
resp, _ := DoGetRequest("/api/lists/refresh", sut.apiListRefresh)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
})
})
@ -74,18 +74,18 @@ var _ = Describe("API tests", func() {
It("should disable blocking resolver", func() {
By("Calling Rest API to deactivate", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
Expect(bc.enabled).Should(BeFalse())
resp, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
})
})
})
When("Disable blocking is called with a wrong parameter", func() {
It("Should return http bad request as return code", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusBadRequest))
resp, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
Expect(resp).Should(HaveHTTPStatus(http.StatusBadRequest))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
})
})
@ -96,8 +96,9 @@ var _ = Describe("API tests", func() {
})
By("Calling Rest API to deactivate blocking for 0.5 sec", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
resp, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
})
By("ensure that the blocking is disabled", func() {
@ -110,13 +111,15 @@ var _ = Describe("API tests", func() {
When("Blocking status is called", func() {
It("should return correct status", func() {
By("enable blocking via API", func() {
httpCode, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
Expect(httpCode).Should(Equal(http.StatusOK))
resp, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
})
By("Query blocking status via API should return 'enabled'", func() {
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(httpCode).Should(Equal(http.StatusOK))
resp, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
var result BlockingStatus
err := json.NewDecoder(body).Decode(&result)
Expect(err).Should(Succeed())
@ -125,13 +128,15 @@ var _ = Describe("API tests", func() {
})
By("disable blocking via API", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
resp, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
})
By("Query blocking status via API again should return 'disabled'", func() {
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(httpCode).Should(Equal(http.StatusOK))
resp, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
var result BlockingStatus
err := json.NewDecoder(body).Decode(&result)

View File

@ -40,7 +40,8 @@ func TestServer(data string) *httptest.Server {
}
// DoGetRequest performs a GET request
func DoGetRequest(url string, fn func(w http.ResponseWriter, r *http.Request)) (code int, body *bytes.Buffer) {
func DoGetRequest(url string,
fn func(w http.ResponseWriter, r *http.Request)) (*httptest.ResponseRecorder, *bytes.Buffer) {
r, _ := http.NewRequest("GET", url, nil)
rr := httptest.NewRecorder()
@ -48,7 +49,7 @@ func DoGetRequest(url string, fn func(w http.ResponseWriter, r *http.Request)) (
handler.ServeHTTP(rr, r)
return rr.Code, rr.Body
return rr, rr.Body
}
// BeDNSRecord returns new dns matcher

View File

@ -25,9 +25,12 @@ import (
)
const (
dohMessageLimit = 512
dnsContentType = "application/dns-message"
corsMaxAge = 5 * time.Minute
dohMessageLimit = 512
contentTypeHeader = "content-type"
dnsContentType = "application/dns-message"
jsonContentType = "application/json"
htmlContentType = "text/html; charset=UTF-8"
corsMaxAge = 5 * time.Minute
)
func secureHeader(next http.Handler) http.Handler {
@ -173,6 +176,9 @@ func extractIP(r *http.Request) string {
// @Router /query [post]
func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
var queryRequest api.QueryRequest
rw.Header().Set(contentTypeHeader, jsonContentType)
err := json.NewDecoder(req.Body).Decode(&queryRequest)
if err != nil {
@ -253,7 +259,7 @@ func createRouter(cfg *config.Config) *chi.Mux {
func configureRootHandler(cfg *config.Config, router *chi.Mux) {
router.Get("/", func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set("content-type", dnsContentType)
writer.Header().Set(contentTypeHeader, htmlContentType)
t := template.New("index")
_, _ = t.Parse(web.IndexTmpl)

View File

@ -290,18 +290,19 @@ var _ = Describe("Running DNS server", func() {
Describe("Prometheus endpoint", func() {
When("Prometheus URL is called", func() {
It("should return prometheus data", func() {
r, err := http.Get("http://localhost:4000/metrics")
resp, err := http.Get("http://localhost:4000/metrics")
Expect(err).Should(Succeed())
Expect(r.StatusCode).Should(Equal(http.StatusOK))
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
})
})
})
Describe("Root endpoint", func() {
When("Root URL is called", func() {
It("should return root page", func() {
r, err := http.Get("http://localhost:4000/")
resp, err := http.Get("http://localhost:4000/")
Expect(err).Should(Succeed())
Expect(r.StatusCode).Should(Equal(http.StatusOK))
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "text/html; charset=UTF-8"))
})
})
})
@ -321,7 +322,8 @@ var _ = Describe("Running DNS server", func() {
Expect(err).Should(Succeed())
defer resp.Body.Close()
Expect(resp.StatusCode).Should(Equal(http.StatusOK))
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
var result api.QueryResult
err = json.NewDecoder(resp.Body).Decode(&result)
@ -384,6 +386,8 @@ var _ = Describe("Running DNS server", func() {
defer resp.Body.Close()
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
rawMsg, err := ioutil.ReadAll(resp.Body)
Expect(err).Should(Succeed())
@ -446,6 +450,7 @@ var _ = Describe("Running DNS server", func() {
Expect(err).Should(Succeed())
defer resp.Body.Close()
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
rawMsg, err := ioutil.ReadAll(resp.Body)
Expect(err).Should(Succeed())
@ -465,6 +470,7 @@ var _ = Describe("Running DNS server", func() {
Expect(err).Should(Succeed())
defer resp.Body.Close()
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
rawMsg, err := ioutil.ReadAll(resp.Body)
Expect(err).Should(Succeed())