refactor(server): simplify HTTP router setup

This commit is contained in:
ThinkChaos 2024-03-29 19:29:40 -04:00
parent 7d510d009b
commit c389a4a0f4
No known key found for this signature in database
2 changed files with 20 additions and 47 deletions

View File

@ -50,7 +50,6 @@ type Server struct {
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
httpsMux *chi.Mux
cert tls.Certificate
}
@ -117,19 +116,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, fmt.Errorf("server creation failed: %w", err)
}
httpRouter := createHTTPRouter(cfg)
httpsRouter := createHTTPSRouter(cfg)
httpListeners, httpsListeners, err := createHTTPListeners(cfg)
if err != nil {
return nil, err
}
if len(httpListeners) != 0 || len(httpsListeners) != 0 {
metrics.Start(httpRouter, cfg.Prometheus)
metrics.Start(httpsRouter, cfg.Prometheus)
}
metrics.RegisterEventListeners()
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
@ -156,25 +147,20 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
httpMux: httpRouter,
httpsMux: httpsRouter,
cert: cert,
}
server.printConfiguration()
server.registerDNSHandlers(ctx)
err = server.registerAPIEndpoints(httpRouter)
openAPIImpl, err := server.createOpenAPIInterfaceImpl()
if err != nil {
return nil, err
}
err = server.registerAPIEndpoints(httpsRouter)
if err != nil {
return nil, err
}
server.httpMux = createHTTPRouter(cfg, openAPIImpl)
server.registerDoHEndpoints(server.httpMux)
return server, err
}
@ -518,7 +504,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
logger().Infof("https server is up and running on addr/port %s", address)
server := http.Server{
Handler: s.httpsMux,
Handler: s.httpMux,
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,

View File

@ -10,6 +10,7 @@ import (
"net/http"
"time"
"github.com/0xERR0R/blocky/metrics"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/api"
@ -37,10 +38,13 @@ const (
func secureHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
if r.TLS != nil {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
}
next.ServeHTTP(w, r)
})
}
@ -64,24 +68,15 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
}
func (s *Server) registerAPIEndpoints(router *chi.Mux) error {
func (s *Server) registerDoHEndpoints(router *chi.Mux) {
const pathDohQuery = "/dns-query"
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return err
}
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
router.Get(pathDohQuery, s.dohGetRequestHandler)
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
router.Post(pathDohQuery, s.dohPostRequestHandler)
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
return nil
}
func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
@ -177,27 +172,15 @@ func (s *Server) Query(
return s.resolve(ctx, req)
}
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux {
router := chi.NewRouter()
configureSecureHeaderHandler(router)
registerHandlers(cfg, router)
return router
}
func createHTTPRouter(cfg *config.Config) *chi.Mux {
router := chi.NewRouter()
registerHandlers(cfg, router)
return router
}
func registerHandlers(cfg *config.Config, router *chi.Mux) {
configureCorsHandler(router)
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
configureDebugHandler(router)
configureDocsHandler(router)
@ -205,6 +188,10 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) {
configureStaticAssetsHandler(router)
configureRootHandler(cfg, router)
metrics.Start(router, cfg.Prometheus)
return router
}
func configureDocsHandler(router *chi.Mux) {