feat: add API endpoint to flush the DNS Cache (#1178)

This commit is contained in:
Dimitri Herzog 2023-09-30 22:13:01 +02:00 committed by GitHub
parent 96e812d57e
commit d77f0ed54f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 342 additions and 36 deletions

View File

@ -98,6 +98,9 @@ type ClientInterface interface {
// BlockingStatus request
BlockingStatus(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error)
// CacheFlush request
CacheFlush(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error)
// ListRefresh request
ListRefresh(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error)
@ -143,6 +146,18 @@ func (c *Client) BlockingStatus(ctx context.Context, reqEditors ...RequestEditor
return c.Client.Do(req)
}
func (c *Client) CacheFlush(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) {
req, err := NewCacheFlushRequest(c.Server)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
if err := c.applyEditors(ctx, req, reqEditors); err != nil {
return nil, err
}
return c.Client.Do(req)
}
func (c *Client) ListRefresh(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) {
req, err := NewListRefreshRequest(c.Server)
if err != nil {
@ -298,6 +313,33 @@ func NewBlockingStatusRequest(server string) (*http.Request, error) {
return req, nil
}
// NewCacheFlushRequest generates requests for CacheFlush
func NewCacheFlushRequest(server string) (*http.Request, error) {
var err error
serverURL, err := url.Parse(server)
if err != nil {
return nil, err
}
operationPath := fmt.Sprintf("/cache/flush")
if operationPath[0] == '/' {
operationPath = "." + operationPath
}
queryURL, err := serverURL.Parse(operationPath)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", queryURL.String(), nil)
if err != nil {
return nil, err
}
return req, nil
}
// NewListRefreshRequest generates requests for ListRefresh
func NewListRefreshRequest(server string) (*http.Request, error) {
var err error
@ -417,6 +459,9 @@ type ClientWithResponsesInterface interface {
// BlockingStatusWithResponse request
BlockingStatusWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*BlockingStatusResponse, error)
// CacheFlushWithResponse request
CacheFlushWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*CacheFlushResponse, error)
// ListRefreshWithResponse request
ListRefreshWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*ListRefreshResponse, error)
@ -490,6 +535,27 @@ func (r BlockingStatusResponse) StatusCode() int {
return 0
}
type CacheFlushResponse struct {
Body []byte
HTTPResponse *http.Response
}
// Status returns HTTPResponse.Status
func (r CacheFlushResponse) Status() string {
if r.HTTPResponse != nil {
return r.HTTPResponse.Status
}
return http.StatusText(0)
}
// StatusCode returns HTTPResponse.StatusCode
func (r CacheFlushResponse) StatusCode() int {
if r.HTTPResponse != nil {
return r.HTTPResponse.StatusCode
}
return 0
}
type ListRefreshResponse struct {
Body []byte
HTTPResponse *http.Response
@ -560,6 +626,15 @@ func (c *ClientWithResponses) BlockingStatusWithResponse(ctx context.Context, re
return ParseBlockingStatusResponse(rsp)
}
// CacheFlushWithResponse request returning *CacheFlushResponse
func (c *ClientWithResponses) CacheFlushWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*CacheFlushResponse, error) {
rsp, err := c.CacheFlush(ctx, reqEditors...)
if err != nil {
return nil, err
}
return ParseCacheFlushResponse(rsp)
}
// ListRefreshWithResponse request returning *ListRefreshResponse
func (c *ClientWithResponses) ListRefreshWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*ListRefreshResponse, error) {
rsp, err := c.ListRefresh(ctx, reqEditors...)
@ -644,6 +719,22 @@ func ParseBlockingStatusResponse(rsp *http.Response) (*BlockingStatusResponse, e
return response, nil
}
// ParseCacheFlushResponse parses an HTTP response from a CacheFlushWithResponse call
func ParseCacheFlushResponse(rsp *http.Response) (*CacheFlushResponse, error) {
bodyBytes, err := io.ReadAll(rsp.Body)
defer func() { _ = rsp.Body.Close() }()
if err != nil {
return nil, err
}
response := &CacheFlushResponse{
Body: bodyBytes,
HTTPResponse: rsp,
}
return response, nil
}
// ParseListRefreshResponse parses an HTTP response from a ListRefreshWithResponse call
func ParseListRefreshResponse(rsp *http.Response) (*ListRefreshResponse, error) {
bodyBytes, err := io.ReadAll(rsp.Body)

View File

@ -23,7 +23,7 @@ type BlockingStatus struct {
Enabled bool
// Disabled group names
DisabledGroups []string
// If blocking is temporary disabled: amount of seconds until blocking will be enabled
// If blocking is temporarily disabled: amount of seconds until blocking will be enabled
AutoEnableInSec int
}
@ -43,21 +43,31 @@ type Querier interface {
Query(question string, qType dns.Type) (*model.Response, error)
}
type CacheControl interface {
FlushCaches()
}
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
HandlerFromMuxWithBaseURL(NewStrictHandler(impl, nil), router, "/api")
}
type OpenAPIInterfaceImpl struct {
control BlockingControl
querier Querier
refresher ListRefresher
control BlockingControl
querier Querier
refresher ListRefresher
cacheControl CacheControl
}
func NewOpenAPIInterfaceImpl(control BlockingControl, querier Querier, refresher ListRefresher) *OpenAPIInterfaceImpl {
func NewOpenAPIInterfaceImpl(control BlockingControl,
querier Querier,
refresher ListRefresher,
cacheControl CacheControl,
) *OpenAPIInterfaceImpl {
return &OpenAPIInterfaceImpl{
control: control,
querier: querier,
refresher: refresher,
control: control,
querier: querier,
refresher: refresher,
cacheControl: cacheControl,
}
}
@ -145,3 +155,11 @@ func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObje
ReturnCode: dns.RcodeToString[resp.Res.Rcode],
}), nil
}
func (i *OpenAPIInterfaceImpl) CacheFlush(_ context.Context,
_ CacheFlushRequestObject,
) (CacheFlushResponseObject, error) {
i.cacheControl.FlushCaches()
return CacheFlush200Response{}, nil
}

View File

@ -28,6 +28,10 @@ type QuerierMock struct {
mock.Mock
}
type CacheControlMock struct {
mock.Mock
}
func (m *ListRefreshMock) RefreshLists() error {
args := m.Called()
@ -56,11 +60,16 @@ func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, e
return args.Get(0).(*model.Response), args.Error(1)
}
func (m *CacheControlMock) FlushCaches() {
_ = m.Called()
}
var _ = Describe("API implementation tests", func() {
var (
blockingControlMock *BlockingControlMock
querierMock *QuerierMock
listRefreshMock *ListRefreshMock
cacheControlMock *CacheControlMock
sut *OpenAPIInterfaceImpl
)
@ -68,7 +77,8 @@ var _ = Describe("API implementation tests", func() {
blockingControlMock = &BlockingControlMock{}
querierMock = &QuerierMock{}
listRefreshMock = &ListRefreshMock{}
sut = NewOpenAPIInterfaceImpl(blockingControlMock, querierMock, listRefreshMock)
cacheControlMock = &CacheControlMock{}
sut = NewOpenAPIInterfaceImpl(blockingControlMock, querierMock, listRefreshMock, cacheControlMock)
})
AfterEach(func() {
@ -213,4 +223,16 @@ var _ = Describe("API implementation tests", func() {
})
})
})
Describe("Cache API", func() {
When("Cache flush is called", func() {
It("should return 200 on success", func() {
cacheControlMock.On("FlushCaches").Return()
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{})
Expect(err).Should(Succeed())
var resp200 CacheFlush200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
})
})
})
})

View File

@ -25,6 +25,9 @@ type ServerInterface interface {
// Blocking status
// (GET /blocking/status)
BlockingStatus(w http.ResponseWriter, r *http.Request)
// Clears the DNS response cache
// (POST /cache/flush)
CacheFlush(w http.ResponseWriter, r *http.Request)
// List refresh
// (POST /lists/refresh)
ListRefresh(w http.ResponseWriter, r *http.Request)
@ -55,6 +58,12 @@ func (_ Unimplemented) BlockingStatus(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Clears the DNS response cache
// (POST /cache/flush)
func (_ Unimplemented) CacheFlush(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// List refresh
// (POST /lists/refresh)
func (_ Unimplemented) ListRefresh(w http.ResponseWriter, r *http.Request) {
@ -142,6 +151,21 @@ func (siw *ServerInterfaceWrapper) BlockingStatus(w http.ResponseWriter, r *http
handler.ServeHTTP(w, r.WithContext(ctx))
}
// CacheFlush operation middleware
func (siw *ServerInterfaceWrapper) CacheFlush(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.CacheFlush(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r.WithContext(ctx))
}
// ListRefresh operation middleware
func (siw *ServerInterfaceWrapper) ListRefresh(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@ -294,6 +318,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/blocking/status", wrapper.BlockingStatus)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/cache/flush", wrapper.CacheFlush)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/lists/refresh", wrapper.ListRefresh)
})
@ -361,6 +388,21 @@ func (response BlockingStatus200JSONResponse) VisitBlockingStatusResponse(w http
return json.NewEncoder(w).Encode(response)
}
type CacheFlushRequestObject struct {
}
type CacheFlushResponseObject interface {
VisitCacheFlushResponse(w http.ResponseWriter) error
}
type CacheFlush200Response struct {
}
func (response CacheFlush200Response) VisitCacheFlushResponse(w http.ResponseWriter) error {
w.WriteHeader(200)
return nil
}
type ListRefreshRequestObject struct {
}
@ -424,6 +466,9 @@ type StrictServerInterface interface {
// Blocking status
// (GET /blocking/status)
BlockingStatus(ctx context.Context, request BlockingStatusRequestObject) (BlockingStatusResponseObject, error)
// Clears the DNS response cache
// (POST /cache/flush)
CacheFlush(ctx context.Context, request CacheFlushRequestObject) (CacheFlushResponseObject, error)
// List refresh
// (POST /lists/refresh)
ListRefresh(ctx context.Context, request ListRefreshRequestObject) (ListRefreshResponseObject, error)
@ -535,6 +580,30 @@ func (sh *strictHandler) BlockingStatus(w http.ResponseWriter, r *http.Request)
}
}
// CacheFlush operation middleware
func (sh *strictHandler) CacheFlush(w http.ResponseWriter, r *http.Request) {
var request CacheFlushRequestObject
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) {
return sh.ssi.CacheFlush(ctx, request.(CacheFlushRequestObject))
}
for _, middleware := range sh.middlewares {
handler = middleware(handler, "CacheFlush")
}
response, err := handler(r.Context(), w, r, request)
if err != nil {
sh.options.ResponseErrorHandlerFunc(w, r, err)
} else if validResponse, ok := response.(CacheFlushResponseObject); ok {
if err := validResponse.VisitCacheFlushResponse(w); err != nil {
sh.options.ResponseErrorHandlerFunc(w, r, err)
}
} else if response != nil {
sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response))
}
}
// ListRefresh operation middleware
func (sh *strictHandler) ListRefresh(w http.ResponseWriter, r *http.Request) {
var request ListRefreshRequestObject

View File

@ -124,4 +124,5 @@ func (e *PrefetchingExpiringLRUCache[T]) TotalCount() int {
// Clear removes all cache entries
func (e *PrefetchingExpiringLRUCache[T]) Clear() {
e.cache.Clear()
e.prefetchingNameCache.Clear()
}

View File

@ -57,13 +57,7 @@ func enableBlocking(_ *cobra.Command, _ []string) error {
return fmt.Errorf("can't execute %w", err)
}
if resp.StatusCode() == http.StatusOK {
log.Log().Info("OK")
} else {
return fmt.Errorf("response NOK, Status: %s", resp.Status())
}
return nil
return printOkOrError(resp, string(resp.Body))
}
func disableBlocking(cmd *cobra.Command, _ []string) error {
@ -86,13 +80,7 @@ func disableBlocking(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("can't execute %w", err)
}
if resp.StatusCode() == http.StatusOK {
log.Log().Info("OK")
} else {
return fmt.Errorf("response NOK, Status: %s", resp.Status())
}
return nil
return printOkOrError(resp, string(resp.Body))
}
func statusBlocking(_ *cobra.Command, _ []string) error {

39
cmd/cache.go Normal file
View File

@ -0,0 +1,39 @@
package cmd
import (
"context"
"fmt"
"github.com/0xERR0R/blocky/api"
"github.com/spf13/cobra"
)
func newCacheCommand() *cobra.Command {
c := &cobra.Command{
Use: "cache",
Short: "Performs cache operations",
}
c.AddCommand(&cobra.Command{
Use: "flush",
Args: cobra.NoArgs,
Aliases: []string{"clear"},
Short: "Flush cache",
RunE: flushCache,
})
return c
}
func flushCache(_ *cobra.Command, _ []string) error {
client, err := api.NewClientWithResponses(apiURL())
if err != nil {
return fmt.Errorf("can't create client: %w", err)
}
resp, err := client.CacheFlushWithResponse(context.Background())
if err != nil {
return fmt.Errorf("can't execute %w", err)
}
return printOkOrError(resp, string(resp.Body))
}

51
cmd/cache_test.go Normal file
View File

@ -0,0 +1,51 @@
package cmd
import (
"net/http"
"net/http/httptest"
"github.com/sirupsen/logrus/hooks/test"
"github.com/0xERR0R/blocky/log"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Cache command", func() {
var (
ts *httptest.Server
mockFn func(w http.ResponseWriter, _ *http.Request)
loggerHook *test.Hook
)
JustBeforeEach(func() {
ts = testHTTPAPIServer(mockFn)
})
JustAfterEach(func() {
ts.Close()
})
BeforeEach(func() {
mockFn = func(w http.ResponseWriter, _ *http.Request) {}
loggerHook = test.NewGlobal()
log.Log().AddHook(loggerHook)
})
AfterEach(func() {
loggerHook.Reset()
})
Describe("flush cache", func() {
When("flush cache is called via REST", func() {
It("should flush caches", func() {
Expect(flushCache(newCacheCommand(), []string{})).Should(Succeed())
Expect(loggerHook.LastEntry().Message).Should(Equal("OK"))
})
})
When("Wrong url is used", func() {
It("Should end with error", func() {
apiPort = 0
err := flushCache(newCacheCommand(), []string{})
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("connection refused"))
})
})
})
})

View File

@ -3,11 +3,8 @@ package cmd
import (
"context"
"fmt"
"net/http"
"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/log"
"github.com/spf13/cobra"
)
@ -42,11 +39,5 @@ func refreshList(_ *cobra.Command, _ []string) error {
return fmt.Errorf("can't execute %w", err)
}
if resp.StatusCode() != http.StatusOK {
return fmt.Errorf("response NOK, %s %s", resp.Status(), string(resp.Body))
}
log.Log().Info("OK")
return nil
return printOkOrError(resp, string(resp.Body))
}

View File

@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"net"
"net/http"
"os"
"strconv"
"strings"
@ -54,7 +55,8 @@ Complete documentation is available at https://github.com/0xERR0R/blocky`,
newServeCommand(),
newBlockingCommand(),
NewListsCommand(),
NewHealthcheckCommand())
NewHealthcheckCommand(),
newCacheCommand())
return c
}
@ -112,3 +114,18 @@ func Execute() {
os.Exit(1)
}
}
type codeWithStatus interface {
StatusCode() int
Status() string
}
func printOkOrError(resp codeWithStatus, body string) error {
if resp.StatusCode() == http.StatusOK {
log.Log().Info("OK")
} else {
return fmt.Errorf("response NOK, %s %s", resp.Status(), body)
}
return nil
}

View File

@ -180,6 +180,16 @@ paths:
schema:
type: string
example: Bad request
/cache/flush:
post:
operationId: cacheFlush
tags:
- cache
summary: Clears the DNS response cache
description: Removes all DNS responses from cache
responses:
'200':
description: All caches cleared
components:
schemas:
api.BlockingStatus:

View File

@ -277,3 +277,8 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
evt.Bus().Publish(event, val)
}
}
func (r *CachingResolver) FlushCaches() {
r.log().Debug("flush caches")
r.resultCache.Clear()
}

View File

@ -30,7 +30,6 @@ const (
dohMessageLimit = 512
contentTypeHeader = "content-type"
dnsContentType = "application/dns-message"
jsonContentType = "application/json"
htmlContentType = "text/html; charset=UTF-8"
yamlContentType = "text/yaml"
corsMaxAge = 5 * time.Minute
@ -57,7 +56,12 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
return nil, fmt.Errorf("no refresh API implementation found %w", err)
}
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher), nil
cacheControl, err := resolver.GetFromChainWithType[api.CacheControl](s.queryResolver)
if err != nil {
return nil, fmt.Errorf("no cache API implementation found %w", err)
}
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
}
func (s *Server) registerAPIEndpoints(router *chi.Mux) error {