mirror of https://github.com/0xERR0R/blocky.git
set content-type header for HTTP endpoints (#581)
This commit is contained in:
parent
3b620102a7
commit
a903565cb0
|
@ -12,6 +12,11 @@ import (
|
|||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
contentTypeHeader = "content-type"
|
||||
jsonContentType = "application/json"
|
||||
)
|
||||
|
||||
// BlockingControl interface to control the blocking status
|
||||
type BlockingControl interface {
|
||||
EnableBlocking()
|
||||
|
@ -57,7 +62,8 @@ func registerListRefreshEndpoints(router chi.Router, refresher ListRefresher) {
|
|||
// @Tags lists
|
||||
// @Success 200 "Lists were reloaded"
|
||||
// @Router /lists/refresh [post]
|
||||
func (l *ListRefreshEndpoint) apiListRefresh(_ http.ResponseWriter, _ *http.Request) {
|
||||
func (l *ListRefreshEndpoint) apiListRefresh(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.Header().Set(contentTypeHeader, jsonContentType)
|
||||
l.refresher.RefreshLists()
|
||||
}
|
||||
|
||||
|
@ -75,10 +81,12 @@ func registerBlockingEndpoints(router chi.Router, control BlockingControl) {
|
|||
// @Tags blocking
|
||||
// @Success 200 "Blocking is enabled"
|
||||
// @Router /blocking/enable [get]
|
||||
func (s *BlockingEndpoint) apiBlockingEnable(_ http.ResponseWriter, _ *http.Request) {
|
||||
func (s *BlockingEndpoint) apiBlockingEnable(rw http.ResponseWriter, _ *http.Request) {
|
||||
log.Log().Info("enabling blocking...")
|
||||
|
||||
s.control.EnableBlocking()
|
||||
|
||||
rw.Header().Set(contentTypeHeader, jsonContentType)
|
||||
}
|
||||
|
||||
// apiBlockingDisable is the http endpoint to disable the blocking status
|
||||
|
@ -98,6 +106,8 @@ func (s *BlockingEndpoint) apiBlockingDisable(rw http.ResponseWriter, req *http.
|
|||
err error
|
||||
)
|
||||
|
||||
rw.Header().Set(contentTypeHeader, jsonContentType)
|
||||
|
||||
// parse duration from query parameter
|
||||
durationParam := req.URL.Query().Get("duration")
|
||||
if len(durationParam) > 0 {
|
||||
|
@ -132,6 +142,8 @@ func (s *BlockingEndpoint) apiBlockingDisable(rw http.ResponseWriter, req *http.
|
|||
func (s *BlockingEndpoint) apiBlockingStatus(rw http.ResponseWriter, _ *http.Request) {
|
||||
status := s.control.BlockingStatus()
|
||||
|
||||
rw.Header().Set(contentTypeHeader, jsonContentType)
|
||||
|
||||
response, err := json.Marshal(status)
|
||||
util.LogOnError("unable to marshal response ", err)
|
||||
|
||||
|
|
|
@ -51,9 +51,9 @@ var _ = Describe("API tests", func() {
|
|||
r := &ListRefreshMock{}
|
||||
sut := &ListRefreshEndpoint{refresher: r}
|
||||
It("should trigger the list refresh", func() {
|
||||
httpCode, _ := DoGetRequest("/api/lists/refresh", sut.apiListRefresh)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
Expect(r.refreshTriggered).Should(BeTrue())
|
||||
resp, _ := DoGetRequest("/api/lists/refresh", sut.apiListRefresh)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -74,18 +74,18 @@ var _ = Describe("API tests", func() {
|
|||
It("should disable blocking resolver", func() {
|
||||
By("Calling Rest API to deactivate", func() {
|
||||
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
Expect(bc.enabled).Should(BeFalse())
|
||||
resp, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
When("Disable blocking is called with a wrong parameter", func() {
|
||||
It("Should return http bad request as return code", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
|
||||
|
||||
Expect(httpCode).Should(Equal(http.StatusBadRequest))
|
||||
resp, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusBadRequest))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -96,8 +96,9 @@ var _ = Describe("API tests", func() {
|
|||
})
|
||||
|
||||
By("Calling Rest API to deactivate blocking for 0.5 sec", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
resp, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
})
|
||||
|
||||
By("ensure that the blocking is disabled", func() {
|
||||
|
@ -110,13 +111,15 @@ var _ = Describe("API tests", func() {
|
|||
When("Blocking status is called", func() {
|
||||
It("should return correct status", func() {
|
||||
By("enable blocking via API", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
resp, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
})
|
||||
|
||||
By("Query blocking status via API should return 'enabled'", func() {
|
||||
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
resp, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
var result BlockingStatus
|
||||
err := json.NewDecoder(body).Decode(&result)
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -125,13 +128,15 @@ var _ = Describe("API tests", func() {
|
|||
})
|
||||
|
||||
By("disable blocking via API", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
resp, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
})
|
||||
|
||||
By("Query blocking status via API again should return 'disabled'", func() {
|
||||
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
resp, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
|
||||
var result BlockingStatus
|
||||
err := json.NewDecoder(body).Decode(&result)
|
||||
|
|
|
@ -40,7 +40,8 @@ func TestServer(data string) *httptest.Server {
|
|||
}
|
||||
|
||||
// DoGetRequest performs a GET request
|
||||
func DoGetRequest(url string, fn func(w http.ResponseWriter, r *http.Request)) (code int, body *bytes.Buffer) {
|
||||
func DoGetRequest(url string,
|
||||
fn func(w http.ResponseWriter, r *http.Request)) (*httptest.ResponseRecorder, *bytes.Buffer) {
|
||||
r, _ := http.NewRequest("GET", url, nil)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
@ -48,7 +49,7 @@ func DoGetRequest(url string, fn func(w http.ResponseWriter, r *http.Request)) (
|
|||
|
||||
handler.ServeHTTP(rr, r)
|
||||
|
||||
return rr.Code, rr.Body
|
||||
return rr, rr.Body
|
||||
}
|
||||
|
||||
// BeDNSRecord returns new dns matcher
|
||||
|
|
|
@ -25,9 +25,12 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
dohMessageLimit = 512
|
||||
dnsContentType = "application/dns-message"
|
||||
corsMaxAge = 5 * time.Minute
|
||||
dohMessageLimit = 512
|
||||
contentTypeHeader = "content-type"
|
||||
dnsContentType = "application/dns-message"
|
||||
jsonContentType = "application/json"
|
||||
htmlContentType = "text/html; charset=UTF-8"
|
||||
corsMaxAge = 5 * time.Minute
|
||||
)
|
||||
|
||||
func secureHeader(next http.Handler) http.Handler {
|
||||
|
@ -173,6 +176,9 @@ func extractIP(r *http.Request) string {
|
|||
// @Router /query [post]
|
||||
func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
|
||||
var queryRequest api.QueryRequest
|
||||
|
||||
rw.Header().Set(contentTypeHeader, jsonContentType)
|
||||
|
||||
err := json.NewDecoder(req.Body).Decode(&queryRequest)
|
||||
|
||||
if err != nil {
|
||||
|
@ -253,7 +259,7 @@ func createRouter(cfg *config.Config) *chi.Mux {
|
|||
|
||||
func configureRootHandler(cfg *config.Config, router *chi.Mux) {
|
||||
router.Get("/", func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set("content-type", dnsContentType)
|
||||
writer.Header().Set(contentTypeHeader, htmlContentType)
|
||||
t := template.New("index")
|
||||
_, _ = t.Parse(web.IndexTmpl)
|
||||
|
||||
|
|
|
@ -290,18 +290,19 @@ var _ = Describe("Running DNS server", func() {
|
|||
Describe("Prometheus endpoint", func() {
|
||||
When("Prometheus URL is called", func() {
|
||||
It("should return prometheus data", func() {
|
||||
r, err := http.Get("http://localhost:4000/metrics")
|
||||
resp, err := http.Get("http://localhost:4000/metrics")
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(r.StatusCode).Should(Equal(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
})
|
||||
})
|
||||
})
|
||||
Describe("Root endpoint", func() {
|
||||
When("Root URL is called", func() {
|
||||
It("should return root page", func() {
|
||||
r, err := http.Get("http://localhost:4000/")
|
||||
resp, err := http.Get("http://localhost:4000/")
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(r.StatusCode).Should(Equal(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "text/html; charset=UTF-8"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -321,7 +322,8 @@ var _ = Describe("Running DNS server", func() {
|
|||
Expect(err).Should(Succeed())
|
||||
defer resp.Body.Close()
|
||||
|
||||
Expect(resp.StatusCode).Should(Equal(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/json"))
|
||||
|
||||
var result api.QueryResult
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
|
@ -384,6 +386,8 @@ var _ = Describe("Running DNS server", func() {
|
|||
defer resp.Body.Close()
|
||||
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
|
||||
|
||||
rawMsg, err := ioutil.ReadAll(resp.Body)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
|
@ -446,6 +450,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
Expect(err).Should(Succeed())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
|
||||
rawMsg, err := ioutil.ReadAll(resp.Body)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
|
@ -465,6 +470,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
Expect(err).Should(Succeed())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
|
||||
Expect(resp).Should(HaveHTTPHeaderWithValue("Content-type", "application/dns-message"))
|
||||
rawMsg, err := ioutil.ReadAll(resp.Body)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
|
|
Loading…
Reference in New Issue