mirror of https://github.com/0xERR0R/blocky.git
Refactoring REST Api
This commit is contained in:
parent
e1848ecddf
commit
00fc86f91f
|
@ -0,0 +1,92 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"blocky/util"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type BlockingControl interface {
|
||||
EnableBlocking()
|
||||
DisableBlocking(duration time.Duration)
|
||||
BlockingStatus() BlockingStatus
|
||||
}
|
||||
|
||||
type BlockingEndpoint struct {
|
||||
control BlockingControl
|
||||
}
|
||||
|
||||
func RegisterEndpoint(router chi.Router, t interface{}) {
|
||||
if bc, ok := t.(BlockingControl); ok {
|
||||
registerBlockingEndpoints(router, bc)
|
||||
}
|
||||
}
|
||||
|
||||
func registerBlockingEndpoints(router chi.Router, control BlockingControl) {
|
||||
s := &BlockingEndpoint{control}
|
||||
// register API endpoints
|
||||
router.Get(BlockingEnablePath, s.apiBlockingEnable)
|
||||
router.Get(BlockingDisablePath, s.apiBlockingDisable)
|
||||
router.Get(BlockingStatusPath, s.apiBlockingStatus)
|
||||
}
|
||||
|
||||
// apiBlockingEnable is the http endpoint to enable the blocking status
|
||||
// @Summary Enable blocking
|
||||
// @Description enable the blocking status
|
||||
// @Tags blocking
|
||||
// @Success 200 "Blocking is enabled"
|
||||
// @Router /blocking/enable [get]
|
||||
func (s *BlockingEndpoint) apiBlockingEnable(_ http.ResponseWriter, _ *http.Request) {
|
||||
log.Info("enabling blocking...")
|
||||
|
||||
s.control.EnableBlocking()
|
||||
}
|
||||
|
||||
// apiBlockingDisable is the http endpoint to disable the blocking status
|
||||
// @Summary Disable blocking
|
||||
// @Description disable the blocking status
|
||||
// @Tags blocking
|
||||
// @Param duration query string false "duration of blocking (Example: 300s, 5m, 1h, 5m30s)" Format(duration)
|
||||
// @Success 200 "Blocking is disabled"
|
||||
// @Failure 400 "Wrong duration format"
|
||||
// @Router /blocking/disable [get]
|
||||
func (s *BlockingEndpoint) apiBlockingDisable(rw http.ResponseWriter, req *http.Request) {
|
||||
var (
|
||||
duration time.Duration
|
||||
err error
|
||||
)
|
||||
|
||||
// parse duration from query parameter
|
||||
durationParam := req.URL.Query().Get("duration")
|
||||
if len(durationParam) > 0 {
|
||||
duration, err = time.ParseDuration(durationParam)
|
||||
if err != nil {
|
||||
log.Errorf("wrong duration format '%s'", durationParam)
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.control.DisableBlocking(duration)
|
||||
}
|
||||
|
||||
// apiBlockingStatus is the http endpoint to get current blocking status
|
||||
// @Summary Blocking status
|
||||
// @Description get current blocking status
|
||||
// @Tags blocking
|
||||
// @Produce json
|
||||
// @Success 200 {object} api.BlockingStatus "Returns current blocking status"
|
||||
// @Router /blocking/status [get]
|
||||
func (s *BlockingEndpoint) apiBlockingStatus(rw http.ResponseWriter, _ *http.Request) {
|
||||
status := s.control.BlockingStatus()
|
||||
|
||||
response, _ := json.Marshal(status)
|
||||
_, err := rw.Write(response)
|
||||
|
||||
util.LogOnError("unable to write response ", err)
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
. "blocky/helpertest"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
)
|
||||
|
||||
type BlockingControlMock struct {
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (b *BlockingControlMock) EnableBlocking() {
|
||||
b.enabled = true
|
||||
}
|
||||
|
||||
func (b *BlockingControlMock) DisableBlocking(_ time.Duration) {
|
||||
b.enabled = false
|
||||
}
|
||||
|
||||
func (b *BlockingControlMock) BlockingStatus() BlockingStatus {
|
||||
return BlockingStatus{Enabled: b.enabled}
|
||||
}
|
||||
|
||||
var _ = Describe("API tests", func() {
|
||||
|
||||
Describe("Register router", func() {
|
||||
RegisterEndpoint(chi.NewRouter(), &BlockingControlMock{})
|
||||
})
|
||||
|
||||
Describe("Control status via API", func() {
|
||||
var (
|
||||
bc *BlockingControlMock
|
||||
sut *BlockingEndpoint
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
bc = &BlockingControlMock{enabled: true}
|
||||
sut = &BlockingEndpoint{control: bc}
|
||||
})
|
||||
|
||||
When("Disable blocking is called", func() {
|
||||
It("schould disable blocking resolver", func() {
|
||||
By("Calling Rest API to deactivate", func() {
|
||||
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
Expect(bc.enabled).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
When("Disable blocking is called with a wrong parameter", func() {
|
||||
It("Should return http bad request as return code", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
|
||||
|
||||
Expect(httpCode).Should(Equal(http.StatusBadRequest))
|
||||
})
|
||||
})
|
||||
|
||||
When("Disable blocking is called with a duration parameter", func() {
|
||||
It("Should disable blocking only for the passed amount of time", func() {
|
||||
By("ensure that the blocking status is active", func() {
|
||||
Expect(bc.enabled).Should(BeTrue())
|
||||
})
|
||||
|
||||
By("Calling Rest API to deactivate blocking for 0.5 sec", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
By("ensure that the blocking is disabled", func() {
|
||||
// now is blocking disabled
|
||||
Expect(bc.enabled).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
When("Blocking status is called", func() {
|
||||
It("should return correct status", func() {
|
||||
By("enable blocking via API", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
By("Query blocking status via API should return 'enabled'", func() {
|
||||
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
var result BlockingStatus
|
||||
err := json.NewDecoder(body).Decode(&result)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Expect(result.Enabled).Should(BeTrue())
|
||||
})
|
||||
|
||||
By("disable blocking via API", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
By("Query blocking status via API again should return 'disabled'", func() {
|
||||
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
|
||||
var result BlockingStatus
|
||||
err := json.NewDecoder(body).Decode(&result)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
Expect(result.Enabled).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,15 @@
|
|||
package api_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestResolver(t *testing.T) {
|
||||
logrus.SetLevel(logrus.WarnLevel)
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "API Suite")
|
||||
}
|
|
@ -6,17 +6,14 @@ import (
|
|||
"blocky/evt"
|
||||
"blocky/lists"
|
||||
"blocky/util"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -57,34 +54,6 @@ type status struct {
|
|||
disableEnd time.Time
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) enableBlocking() {
|
||||
s := r.status
|
||||
s.enableTimer.Stop()
|
||||
s.enabled = true
|
||||
|
||||
evt.Bus().Publish(evt.BlockingEnabledEvent, true)
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) disableBlocking(duration time.Duration) {
|
||||
s := r.status
|
||||
s.enableTimer.Stop()
|
||||
s.enabled = false
|
||||
|
||||
evt.Bus().Publish(evt.BlockingEnabledEvent, false)
|
||||
|
||||
s.disableEnd = time.Now().Add(duration)
|
||||
|
||||
if duration == 0 {
|
||||
log.Info("disable blocking")
|
||||
} else {
|
||||
log.Infof("disable blocking for %s", duration)
|
||||
s.enableTimer = time.AfterFunc(duration, func() {
|
||||
r.enableBlocking()
|
||||
log.Info("blocking enabled again")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// checks request's question (domain name) against black and white lists
|
||||
type BlockingResolver struct {
|
||||
NextResolver
|
||||
|
@ -96,7 +65,7 @@ type BlockingResolver struct {
|
|||
status *status
|
||||
}
|
||||
|
||||
func NewBlockingResolver(router *chi.Mux, cfg config.BlockingConfig) ChainedResolver {
|
||||
func NewBlockingResolver(cfg config.BlockingConfig) ChainedResolver {
|
||||
blockHandler := createBlockHandler(cfg)
|
||||
blacklistMatcher := lists.NewListCache(lists.BLACKLIST, cfg.BlackLists, cfg.RefreshPeriod)
|
||||
whitelistMatcher := lists.NewListCache(lists.WHITELIST, cfg.WhiteLists, cfg.RefreshPeriod)
|
||||
|
@ -114,74 +83,47 @@ func NewBlockingResolver(router *chi.Mux, cfg config.BlockingConfig) ChainedReso
|
|||
},
|
||||
}
|
||||
|
||||
// register API endpoints
|
||||
router.Get(api.BlockingEnablePath, res.apiBlockingEnable)
|
||||
router.Get(api.BlockingDisablePath, res.apiBlockingDisable)
|
||||
router.Get(api.BlockingStatusPath, res.apiBlockingStatus)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// apiBlockingEnable is the http endpoint to enable the blocking status
|
||||
// @Summary Enable blocking
|
||||
// @Description enable the blocking status
|
||||
// @Tags blocking
|
||||
// @Success 200 "Blocking is enabled"
|
||||
// @Router /blocking/enable [get]
|
||||
func (r *BlockingResolver) apiBlockingEnable(_ http.ResponseWriter, _ *http.Request) {
|
||||
log.Info("enabling blocking...")
|
||||
r.enableBlocking()
|
||||
func (r *BlockingResolver) EnableBlocking() {
|
||||
s := r.status
|
||||
s.enableTimer.Stop()
|
||||
s.enabled = true
|
||||
|
||||
evt.Bus().Publish(evt.BlockingEnabledEvent, true)
|
||||
}
|
||||
|
||||
// apiBlockingStatus is the http endpoint to get current blocking status
|
||||
// @Summary Blocking status
|
||||
// @Description get current blocking status
|
||||
// @Tags blocking
|
||||
// @Produce json
|
||||
// @Success 200 {object} api.BlockingStatus "Returns current blocking status"
|
||||
// @Router /blocking/status [get]
|
||||
func (r *BlockingResolver) apiBlockingStatus(rw http.ResponseWriter, _ *http.Request) {
|
||||
func (r *BlockingResolver) DisableBlocking(duration time.Duration) {
|
||||
s := r.status
|
||||
s.enableTimer.Stop()
|
||||
s.enabled = false
|
||||
|
||||
evt.Bus().Publish(evt.BlockingEnabledEvent, false)
|
||||
|
||||
s.disableEnd = time.Now().Add(duration)
|
||||
|
||||
if duration == 0 {
|
||||
log.Info("disable blocking")
|
||||
} else {
|
||||
log.Infof("disable blocking for %s", duration)
|
||||
s.enableTimer = time.AfterFunc(duration, func() {
|
||||
r.EnableBlocking()
|
||||
log.Info("blocking enabled again")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) BlockingStatus() api.BlockingStatus {
|
||||
var autoEnableDuration time.Duration
|
||||
if !r.status.enabled && r.status.disableEnd.After(time.Now()) {
|
||||
autoEnableDuration = time.Until(r.status.disableEnd)
|
||||
}
|
||||
|
||||
response, _ := json.Marshal(api.BlockingStatus{
|
||||
return api.BlockingStatus{
|
||||
Enabled: r.status.enabled,
|
||||
AutoEnableInSec: uint(autoEnableDuration.Seconds()),
|
||||
})
|
||||
_, err := rw.Write(response)
|
||||
|
||||
util.LogOnError("unable to write response ", err)
|
||||
}
|
||||
|
||||
// apiBlockingDisable is the http endpoint to disable the blocking status
|
||||
// @Summary Disable blocking
|
||||
// @Description disable the blocking status
|
||||
// @Tags blocking
|
||||
// @Param duration query string false "duration of blocking (Example: 300s, 5m, 1h, 5m30s)" Format(duration)
|
||||
// @Success 200 "Blocking is disabled"
|
||||
// @Failure 400 "Wrong duration format"
|
||||
// @Router /blocking/disable [get]
|
||||
func (r *BlockingResolver) apiBlockingDisable(rw http.ResponseWriter, req *http.Request) {
|
||||
var (
|
||||
duration time.Duration
|
||||
err error
|
||||
)
|
||||
|
||||
// parse duration from query parameter
|
||||
durationParam := req.URL.Query().Get("duration")
|
||||
if len(durationParam) > 0 {
|
||||
duration, err = time.ParseDuration(durationParam)
|
||||
if err != nil {
|
||||
log.Errorf("wrong duration format '%s'", durationParam)
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r.disableBlocking(duration)
|
||||
}
|
||||
|
||||
// returns groups, which have only whitelist entries
|
||||
|
|
|
@ -1,19 +1,15 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"blocky/api"
|
||||
"blocky/config"
|
||||
. "blocky/evt"
|
||||
. "blocky/helpertest"
|
||||
"blocky/lists"
|
||||
"blocky/util"
|
||||
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/miekg/dns"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -63,7 +59,7 @@ badcnamedomain.com`)
|
|||
JustBeforeEach(func() {
|
||||
m = &resolverMock{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
sut = NewBlockingResolver(chi.NewRouter(), sutConfig).(*BlockingResolver)
|
||||
sut = NewBlockingResolver(sutConfig).(*BlockingResolver)
|
||||
sut.Next(m)
|
||||
})
|
||||
|
||||
|
@ -92,7 +88,7 @@ badcnamedomain.com`)
|
|||
Expect(err).Should(Succeed())
|
||||
|
||||
// recreate to trigger a reload
|
||||
sut = NewBlockingResolver(chi.NewRouter(), sutConfig).(*BlockingResolver)
|
||||
sut = NewBlockingResolver(sutConfig).(*BlockingResolver)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
|
@ -449,8 +445,7 @@ badcnamedomain.com`)
|
|||
})
|
||||
|
||||
By("Calling Rest API to deactivate", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
sut.DisableBlocking(0)
|
||||
})
|
||||
|
||||
By("perform the same query again", func() {
|
||||
|
@ -466,14 +461,6 @@ badcnamedomain.com`)
|
|||
})
|
||||
})
|
||||
|
||||
When("Disable blocking is called with a wrong parameter", func() {
|
||||
It("Should return http bad request as return code", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=xyz", sut.apiBlockingDisable)
|
||||
|
||||
Expect(httpCode).Should(Equal(http.StatusBadRequest))
|
||||
})
|
||||
})
|
||||
|
||||
When("Disable blocking is called with a duration parameter", func() {
|
||||
It("No query should be blocked only for passed amount of time", func() {
|
||||
By("Perform query to ensure that the blocking status is active", func() {
|
||||
|
@ -487,9 +474,8 @@ badcnamedomain.com`)
|
|||
err := Bus().SubscribeOnce(BlockingEnabledEvent, func(state bool) {
|
||||
enabled = state
|
||||
})
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
sut.DisableBlocking(500 * time.Millisecond)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
Expect(enabled).Should(BeFalse())
|
||||
})
|
||||
|
||||
|
@ -522,32 +508,20 @@ badcnamedomain.com`)
|
|||
When("Blocking status is called", func() {
|
||||
It("should return correct status", func() {
|
||||
By("enable blocking via API", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/enable", sut.apiBlockingEnable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
sut.EnableBlocking()
|
||||
})
|
||||
|
||||
By("Query blocking status via API should return 'enabled'", func() {
|
||||
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
var result api.BlockingStatus
|
||||
err := json.NewDecoder(body).Decode(&result)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
result := sut.BlockingStatus()
|
||||
Expect(result.Enabled).Should(BeTrue())
|
||||
})
|
||||
|
||||
By("disable blocking via API", func() {
|
||||
httpCode, _ := DoGetRequest("/api/blocking/disable?duration=500ms", sut.apiBlockingDisable)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
sut.DisableBlocking(500 * time.Millisecond)
|
||||
})
|
||||
|
||||
By("Query blocking status via API again should return 'disabled'", func() {
|
||||
httpCode, body := DoGetRequest("/api/blocking/status", sut.apiBlockingStatus)
|
||||
Expect(httpCode).Should(Equal(http.StatusOK))
|
||||
|
||||
var result api.BlockingStatus
|
||||
err := json.NewDecoder(body).Decode(&result)
|
||||
Expect(err).Should(Succeed())
|
||||
result := sut.BlockingStatus()
|
||||
|
||||
Expect(result.Enabled).Should(BeFalse())
|
||||
})
|
||||
|
@ -591,7 +565,7 @@ badcnamedomain.com`)
|
|||
|
||||
logrus.StandardLogger().ExitFunc = func(int) { fatal = true }
|
||||
|
||||
_ = NewBlockingResolver(chi.NewRouter(), config.BlockingConfig{
|
||||
_ = NewBlockingResolver(config.BlockingConfig{
|
||||
BlockType: "wrong",
|
||||
})
|
||||
|
||||
|
|
|
@ -3,8 +3,6 @@ package resolver
|
|||
import (
|
||||
"blocky/config"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -13,8 +11,7 @@ var _ = Describe("Resolver", func() {
|
|||
Describe("Creating resolver chain", func() {
|
||||
When("A chain of resolvers will be created", func() {
|
||||
It("should be iterable by calling 'GetNext'", func() {
|
||||
ch := Chain(NewBlockingResolver(chi.NewRouter(),
|
||||
config.BlockingConfig{}), NewClientNamesResolver(config.ClientLookupConfig{}))
|
||||
ch := Chain(NewBlockingResolver(config.BlockingConfig{}), NewClientNamesResolver(config.ClientLookupConfig{}))
|
||||
c, ok := ch.(ChainedResolver)
|
||||
Expect(ok).Should(BeTrue())
|
||||
|
||||
|
@ -24,7 +21,7 @@ var _ = Describe("Resolver", func() {
|
|||
})
|
||||
When("'Name' will be called", func() {
|
||||
It("should return resolver name", func() {
|
||||
name := Name(NewBlockingResolver(chi.NewRouter(), config.BlockingConfig{}))
|
||||
name := Name(NewBlockingResolver(config.BlockingConfig{}))
|
||||
Expect(name).Should(Equal("BlockingResolver"))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"blocky/api"
|
||||
"blocky/config"
|
||||
"blocky/metrics"
|
||||
"blocky/resolver"
|
||||
|
@ -74,10 +75,11 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
|
|||
|
||||
metrics.RegisterEventListeners()
|
||||
|
||||
queryResolver := createQueryResolver(cfg)
|
||||
server = &Server{
|
||||
udpServer: udpServer,
|
||||
tcpServer: tcpServer,
|
||||
queryResolver: createQueryResolver(cfg, router),
|
||||
queryResolver: queryResolver,
|
||||
cfg: cfg,
|
||||
httpListener: httpListener,
|
||||
httpsListener: httpsListener,
|
||||
|
@ -90,9 +92,23 @@ func NewServer(cfg *config.Config) (server *Server, err error) {
|
|||
server.registerDNSHandlers(tcpServer)
|
||||
server.registerAPIEndpoints(router)
|
||||
|
||||
registerResolverAPIEndpoints(router, queryResolver)
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func registerResolverAPIEndpoints(router chi.Router, res resolver.Resolver) {
|
||||
for res != nil {
|
||||
api.RegisterEndpoint(router, res)
|
||||
|
||||
if cr, ok := res.(resolver.ChainedResolver); ok {
|
||||
res = cr.GetNext()
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createTCPServer(address string) *dns.Server {
|
||||
tcpServer := &dns.Server{
|
||||
Addr: address,
|
||||
|
@ -119,7 +135,7 @@ func createUDPServer(address string) *dns.Server {
|
|||
return udpServer
|
||||
}
|
||||
|
||||
func createQueryResolver(cfg *config.Config, router *chi.Mux) resolver.Resolver {
|
||||
func createQueryResolver(cfg *config.Config) resolver.Resolver {
|
||||
return resolver.Chain(
|
||||
resolver.NewClientNamesResolver(cfg.ClientLookup),
|
||||
resolver.NewQueryLoggingResolver(cfg.QueryLog),
|
||||
|
@ -127,7 +143,7 @@ func createQueryResolver(cfg *config.Config, router *chi.Mux) resolver.Resolver
|
|||
resolver.NewMetricsResolver(cfg.Prometheus),
|
||||
resolver.NewConditionalUpstreamResolver(cfg.Conditional),
|
||||
resolver.NewCustomDNSResolver(cfg.CustomDNS),
|
||||
resolver.NewBlockingResolver(router, cfg.Blocking),
|
||||
resolver.NewBlockingResolver(cfg.Blocking),
|
||||
resolver.NewCachingResolver(cfg.Caching),
|
||||
resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue