Refactoring REST Api

This commit is contained in:
Dimitri Herzog 2021-02-04 21:59:41 +01:00
parent e1848ecddf
commit 00fc86f91f
7 changed files with 286 additions and 130 deletions

92
api/api_endpoints.go Normal file
View File

@ -0,0 +1,92 @@
package api
import (
"blocky/util"
"encoding/json"
"net/http"
"time"
"github.com/go-chi/chi"
log "github.com/sirupsen/logrus"
)
type BlockingControl interface {
EnableBlocking()
DisableBlocking(duration time.Duration)
BlockingStatus() BlockingStatus
}
type BlockingEndpoint struct {
control BlockingControl
}
func RegisterEndpoint(router chi.Router, t interface{}) {
if bc, ok := t.(BlockingControl); ok {
registerBlockingEndpoints(router, bc)
}
}
func registerBlockingEndpoints(router chi.Router, control BlockingControl) {
s := &BlockingEndpoint{control}
// register API endpoints
router.Get(BlockingEnablePath, s.apiBlockingEnable)
router.Get(BlockingDisablePath, s.apiBlockingDisable)
router.Get(BlockingStatusPath, s.apiBlockingStatus)
}
// apiBlockingEnable is the http endpoint to enable the blocking status
// @Summary Enable blocking
// @Description enable the blocking status
// @Tags blocking
// @Success 200 "Blocking is enabled"
// @Router /blocking/enable [get]
func (s *BlockingEndpoint) apiBlockingEnable(_ http.ResponseWriter, _ *http.Request) {
log.Info("enabling blocking...")
s.control.EnableBlocking()
}
// apiBlockingDisable is the http endpoint to disable the blocking status
// @Summary Disable blocking
// @Description disable the blocking status
// @Tags blocking
// @Param duration query string false "duration of blocking (Example: 300s, 5m, 1h, 5m30s)" Format(duration)
// @Success 200 "Blocking is disabled"
// @Failure 400 "Wrong duration format"
// @Router /blocking/disable [get]
func (s *BlockingEndpoint) apiBlockingDisable(rw http.ResponseWriter, req *http.Request) {
var (
duration time.Duration
err error
)
// parse duration from query parameter
durationParam := req.URL.Query().Get("duration")
if len(durationParam) > 0 {
duration, err = time.ParseDuration(durationParam)
if err != nil {
log.Errorf("wrong duration format '%s'", durationParam)
rw.WriteHeader(http.StatusBadRequest)
return
}
}
s.control.DisableBlocking(duration)
}
// apiBlockingStatus is the http endpoint to get current blocking status
// @Summary Blocking status
// @Description get current blocking status
// @Tags blocking
// @Produce json
// @Success 200 {object} api.BlockingStatus "Returns current blocking status"
// @Router /blocking/status [get]
func (s *BlockingEndpoint) apiBlockingStatus(rw http.ResponseWriter, _ *http.Request) {
status := s.control.BlockingStatus()
response, _ := json.Marshal(status)
_, err := rw.Write(response)
util.LogOnError("unable to write response ", err)
}

120
api/api_endpoints_test.go Normal file
View File

@ -0,0 +1,120 @@
package api
import (
. "blocky/helpertest"
"encoding/json"
"net/http"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/go-chi/chi"
)
type BlockingControlMock struct {
enabled bool
}
func (b *BlockingControlMock) EnableBlocking() {
b.enabled = true
}
func (b *BlockingControlMock) DisableBlocking(_ time.Duration) {
b.enabled = false
}
func (b *BlockingControlMock) BlockingStatus() BlockingStatus {
return BlockingStatus{Enabled: b.enabled}
}
var _ = Describe("API tests", func() {
Describe("Register router", func() {
RegisterEndpoint(chi.NewRouter(), &BlockingControlMock{})
})
Describe("Control status via API", func() {
var (
bc *BlockingControlMock
sut *BlockingEndpoint
)
BeforeEach(func() {
bc = &BlockingControlMock{enabled: true}
sut = &BlockingEndpoint{control: bc}
})
When("Disable blocking is called", func() {
It("schould disable blocking resolver", func() {
By("Calling Rest API to deactivate", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
Expect(bc.enabled).Should(BeFalse())
})
})
})
When("Disable blocking is called with a wrong parameter", func() {
It("Should return http bad request as return code", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusBadRequest))
})
})
When("Disable blocking is called with a duration parameter", func() {
It("Should disable blocking only for the passed amount of time", func() {
By("ensure that the blocking status is active", func() {
Expect(bc.enabled).Should(BeTrue())
})
By("Calling Rest API to deactivate blocking for 0.5 sec", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
})
By("ensure that the blocking is disabled", func() {
// now is blocking disabled
Expect(bc.enabled).Should(BeFalse())
})
})
})
When("Blocking status is called", func() {
It("should return correct status", func() {
By("enable blocking via API", func() {
httpCode, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
Expect(httpCode).Should(Equal(http.StatusOK))
})
By("Query blocking status via API should return 'enabled'", func() {
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(httpCode).Should(Equal(http.StatusOK))
var result BlockingStatus
err := json.NewDecoder(body).Decode(&result)
Expect(err).Should(Succeed())
Expect(result.Enabled).Should(BeTrue())
})
By("disable blocking via API", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
})
By("Query blocking status via API again should return 'disabled'", func() {
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(httpCode).Should(Equal(http.StatusOK))
var result BlockingStatus
err := json.NewDecoder(body).Decode(&result)
Expect(err).Should(Succeed())
Expect(result.Enabled).Should(BeFalse())
})
})
})
})
})

15
api/api_suite_test.go Normal file
View File

@ -0,0 +1,15 @@
package api_test
import (
"testing"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus"
)
func TestResolver(t *testing.T) {
logrus.SetLevel(logrus.WarnLevel)
RegisterFailHandler(Fail)
RunSpecs(t, "API Suite")
}

View File

@ -6,17 +6,14 @@ import (
"blocky/evt"
"blocky/lists"
"blocky/util"
"encoding/json"
"fmt"
"net"
"net/http"
"path/filepath"
"reflect"
"sort"
"strings"
"time"
"github.com/go-chi/chi"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
@ -57,34 +54,6 @@ type status struct {
disableEnd time.Time
}
func (r *BlockingResolver) enableBlocking() {
s := r.status
s.enableTimer.Stop()
s.enabled = true
evt.Bus().Publish(evt.BlockingEnabledEvent, true)
}
func (r *BlockingResolver) disableBlocking(duration time.Duration) {
s := r.status
s.enableTimer.Stop()
s.enabled = false
evt.Bus().Publish(evt.BlockingEnabledEvent, false)
s.disableEnd = time.Now().Add(duration)
if duration == 0 {
log.Info("disable blocking")
} else {
log.Infof("disable blocking for %s", duration)
s.enableTimer = time.AfterFunc(duration, func() {
r.enableBlocking()
log.Info("blocking enabled again")
})
}
}
// checks request's question (domain name) against black and white lists
type BlockingResolver struct {
NextResolver
@ -96,7 +65,7 @@ type BlockingResolver struct {
status *status
}
func NewBlockingResolver(router *chi.Mux, cfg config.BlockingConfig) ChainedResolver {
func NewBlockingResolver(cfg config.BlockingConfig) ChainedResolver {
blockHandler := createBlockHandler(cfg)
blacklistMatcher := lists.NewListCache(lists.BLACKLIST, cfg.BlackLists, cfg.RefreshPeriod)
whitelistMatcher := lists.NewListCache(lists.WHITELIST, cfg.WhiteLists, cfg.RefreshPeriod)
@ -114,74 +83,47 @@ func NewBlockingResolver(router *chi.Mux, cfg config.BlockingConfig) ChainedReso
},
}
// register API endpoints
router.Get(api.BlockingEnablePath, res.apiBlockingEnable)
router.Get(api.BlockingDisablePath, res.apiBlockingDisable)
router.Get(api.BlockingStatusPath, res.apiBlockingStatus)
return res
}
// apiBlockingEnable is the http endpoint to enable the blocking status
// @Summary Enable blocking
// @Description enable the blocking status
// @Tags blocking
// @Success 200 "Blocking is enabled"
// @Router /blocking/enable [get]
func (r *BlockingResolver) apiBlockingEnable(_ http.ResponseWriter, _ *http.Request) {
log.Info("enabling blocking...")
r.enableBlocking()
func (r *BlockingResolver) EnableBlocking() {
s := r.status
s.enableTimer.Stop()
s.enabled = true
evt.Bus().Publish(evt.BlockingEnabledEvent, true)
}
// apiBlockingStatus is the http endpoint to get current blocking status
// @Summary Blocking status
// @Description get current blocking status
// @Tags blocking
// @Produce json
// @Success 200 {object} api.BlockingStatus "Returns current blocking status"
// @Router /blocking/status [get]
func (r *BlockingResolver) apiBlockingStatus(rw http.ResponseWriter, _ *http.Request) {
func (r *BlockingResolver) DisableBlocking(duration time.Duration) {
s := r.status
s.enableTimer.Stop()
s.enabled = false
evt.Bus().Publish(evt.BlockingEnabledEvent, false)
s.disableEnd = time.Now().Add(duration)
if duration == 0 {
log.Info("disable blocking")
} else {
log.Infof("disable blocking for %s", duration)
s.enableTimer = time.AfterFunc(duration, func() {
r.EnableBlocking()
log.Info("blocking enabled again")
})
}
}
func (r *BlockingResolver) BlockingStatus() api.BlockingStatus {
var autoEnableDuration time.Duration
if !r.status.enabled && r.status.disableEnd.After(time.Now()) {
autoEnableDuration = time.Until(r.status.disableEnd)
}
response, _ := json.Marshal(api.BlockingStatus{
return api.BlockingStatus{
Enabled: r.status.enabled,
AutoEnableInSec: uint(autoEnableDuration.Seconds()),
})
_, err := rw.Write(response)
util.LogOnError("unable to write response ", err)
}
// apiBlockingDisable is the http endpoint to disable the blocking status
// @Summary Disable blocking
// @Description disable the blocking status
// @Tags blocking
// @Param duration query string false "duration of blocking (Example: 300s, 5m, 1h, 5m30s)" Format(duration)
// @Success 200 "Blocking is disabled"
// @Failure 400 "Wrong duration format"
// @Router /blocking/disable [get]
func (r *BlockingResolver) apiBlockingDisable(rw http.ResponseWriter, req *http.Request) {
var (
duration time.Duration
err error
)
// parse duration from query parameter
durationParam := req.URL.Query().Get("duration")
if len(durationParam) > 0 {
duration, err = time.ParseDuration(durationParam)
if err != nil {
log.Errorf("wrong duration format '%s'", durationParam)
rw.WriteHeader(http.StatusBadRequest)
return
}
}
r.disableBlocking(duration)
}
// returns groups, which have only whitelist entries

View File

@ -1,19 +1,15 @@
package resolver
import (
"blocky/api"
"blocky/config"
. "blocky/evt"
. "blocky/helpertest"
"blocky/lists"
"blocky/util"
"encoding/json"
"net/http"
"os"
"time"
"github.com/go-chi/chi"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -63,7 +59,7 @@ badcnamedomain.com`)
JustBeforeEach(func() {
m = &resolverMock{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
sut = NewBlockingResolver(chi.NewRouter(), sutConfig).(*BlockingResolver)
sut = NewBlockingResolver(sutConfig).(*BlockingResolver)
sut.Next(m)
})
@ -92,7 +88,7 @@ badcnamedomain.com`)
Expect(err).Should(Succeed())
// recreate to trigger a reload
sut = NewBlockingResolver(chi.NewRouter(), sutConfig).(*BlockingResolver)
sut = NewBlockingResolver(sutConfig).(*BlockingResolver)
time.Sleep(time.Second)
@ -449,8 +445,7 @@ badcnamedomain.com`)
})
By("Calling Rest API to deactivate", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
sut.DisableBlocking(0)
})
By("perform the same query again", func() {
@ -466,14 +461,6 @@ badcnamedomain.com`)
})
})
When("Disable blocking is called with a wrong parameter", func() {
It("Should return http bad request as return code", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusBadRequest))
})
})
When("Disable blocking is called with a duration parameter", func() {
It("No query should be blocked only for passed amount of time", func() {
By("Perform query to ensure that the blocking status is active", func() {
@ -487,9 +474,8 @@ badcnamedomain.com`)
err := Bus().SubscribeOnce(BlockingEnabledEvent, func(state bool) {
enabled = state
})
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
sut.DisableBlocking(500 * time.Millisecond)
Expect(err).Should(Succeed())
Expect(httpCode).Should(Equal(http.StatusOK))
Expect(enabled).Should(BeFalse())
})
@ -522,32 +508,20 @@ badcnamedomain.com`)
When("Blocking status is called", func() {
It("should return correct status", func() {
By("enable blocking via API", func() {
httpCode, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
Expect(httpCode).Should(Equal(http.StatusOK))
sut.EnableBlocking()
})
By("Query blocking status via API should return 'enabled'", func() {
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(httpCode).Should(Equal(http.StatusOK))
var result api.BlockingStatus
err := json.NewDecoder(body).Decode(&result)
Expect(err).Should(Succeed())
result := sut.BlockingStatus()
Expect(result.Enabled).Should(BeTrue())
})
By("disable blocking via API", func() {
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
Expect(httpCode).Should(Equal(http.StatusOK))
sut.DisableBlocking(500 * time.Millisecond)
})
By("Query blocking status via API again should return 'disabled'", func() {
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
Expect(httpCode).Should(Equal(http.StatusOK))
var result api.BlockingStatus
err := json.NewDecoder(body).Decode(&result)
Expect(err).Should(Succeed())
result := sut.BlockingStatus()
Expect(result.Enabled).Should(BeFalse())
})
@ -591,7 +565,7 @@ badcnamedomain.com`)
logrus.StandardLogger().ExitFunc = func(int) { fatal = true }
_ = NewBlockingResolver(chi.NewRouter(), config.BlockingConfig{
_ = NewBlockingResolver(config.BlockingConfig{
BlockType: "wrong",
})

View File

@ -3,8 +3,6 @@ package resolver
import (
"blocky/config"
"github.com/go-chi/chi"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@ -13,8 +11,7 @@ var _ = Describe("Resolver", func() {
Describe("Creating resolver chain", func() {
When("A chain of resolvers will be created", func() {
It("should be iterable by calling 'GetNext'", func() {
ch := Chain(NewBlockingResolver(chi.NewRouter(),
config.BlockingConfig{}), NewClientNamesResolver(config.ClientLookupConfig{}))
ch := Chain(NewBlockingResolver(config.BlockingConfig{}), NewClientNamesResolver(config.ClientLookupConfig{}))
c, ok := ch.(ChainedResolver)
Expect(ok).Should(BeTrue())
@ -24,7 +21,7 @@ var _ = Describe("Resolver", func() {
})
When("'Name' will be called", func() {
It("should return resolver name", func() {
name := Name(NewBlockingResolver(chi.NewRouter(), config.BlockingConfig{}))
name := Name(NewBlockingResolver(config.BlockingConfig{}))
Expect(name).Should(Equal("BlockingResolver"))
})
})

View File

@ -1,6 +1,7 @@
package server
import (
"blocky/api"
"blocky/config"
"blocky/metrics"
"blocky/resolver"
@ -74,10 +75,11 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
metrics.RegisterEventListeners()
queryResolver := createQueryResolver(cfg)
server = &Server{
udpServer: udpServer,
tcpServer: tcpServer,
queryResolver: createQueryResolver(cfg, router),
queryResolver: queryResolver,
cfg: cfg,
httpListener: httpListener,
httpsListener: httpsListener,
@ -90,9 +92,23 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
server.registerDNSHandlers(tcpServer)
server.registerAPIEndpoints(router)
registerResolverAPIEndpoints(router, queryResolver)
return server, nil
}
func registerResolverAPIEndpoints(router chi.Router, res resolver.Resolver) {
for res != nil {
api.RegisterEndpoint(router, res)
if cr, ok := res.(resolver.ChainedResolver); ok {
res = cr.GetNext()
} else {
return
}
}
}
func createTCPServer(address string) *dns.Server {
tcpServer := &dns.Server{
Addr: address,
@ -119,7 +135,7 @@ func createUDPServer(address string) *dns.Server {
return udpServer
}
func createQueryResolver(cfg *config.Config, router *chi.Mux) resolver.Resolver {
func createQueryResolver(cfg *config.Config) resolver.Resolver {
return resolver.Chain(
resolver.NewClientNamesResolver(cfg.ClientLookup),
resolver.NewQueryLoggingResolver(cfg.QueryLog),
@ -127,7 +143,7 @@ func createQueryResolver(cfg *config.Config, router *chi.Mux) resolver.Resolver
resolver.NewMetricsResolver(cfg.Prometheus),
resolver.NewConditionalUpstreamResolver(cfg.Conditional),
resolver.NewCustomDNSResolver(cfg.CustomDNS),
resolver.NewBlockingResolver(router, cfg.Blocking),
resolver.NewBlockingResolver(cfg.Blocking),
resolver.NewCachingResolver(cfg.Caching),
resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers),
)