diff --git a/server/server.go b/server/server.go index 3383dd7a..a42c6df8 100644 --- a/server/server.go +++ b/server/server.go @@ -50,7 +50,6 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config httpMux *chi.Mux - httpsMux *chi.Mux cert tls.Certificate } @@ -117,19 +116,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpRouter := createHTTPRouter(cfg) - httpsRouter := createHTTPSRouter(cfg) - httpListeners, httpsListeners, err := createHTTPListeners(cfg) if err != nil { return nil, err } - if len(httpListeners) != 0 || len(httpsListeners) != 0 { - metrics.Start(httpRouter, cfg.Prometheus) - metrics.Start(httpsRouter, cfg.Prometheus) - } - metrics.RegisterEventListeners() bootstrap, err := resolver.NewBootstrap(ctx, cfg) @@ -156,25 +147,20 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err cfg: cfg, httpListeners: httpListeners, httpsListeners: httpsListeners, - httpMux: httpRouter, - httpsMux: httpsRouter, cert: cert, } server.printConfiguration() server.registerDNSHandlers(ctx) - err = server.registerAPIEndpoints(httpRouter) + openAPIImpl, err := server.createOpenAPIInterfaceImpl() if err != nil { return nil, err } - err = server.registerAPIEndpoints(httpsRouter) - - if err != nil { - return nil, err - } + server.httpMux = createHTTPRouter(cfg, openAPIImpl) + server.registerDoHEndpoints(server.httpMux) return server, err } @@ -518,7 +504,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Infof("https server is up and running on addr/port %s", address) server := http.Server{ - Handler: s.httpsMux, + Handler: s.httpMux, ReadTimeout: readTimeout, ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 54f43a12..582a2112 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "github.com/0xERR0R/blocky/metrics" "github.com/0xERR0R/blocky/resolver" "github.com/0xERR0R/blocky/api" @@ -37,10 +38,13 @@ const ( func secureHeader(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("strict-transport-security", "max-age=63072000") - w.Header().Set("x-frame-options", "DENY") - w.Header().Set("x-content-type-options", "nosniff") - w.Header().Set("x-xss-protection", "1; mode=block") + if r.TLS != nil { + w.Header().Set("strict-transport-security", "max-age=63072000") + w.Header().Set("x-frame-options", "DENY") + w.Header().Set("x-content-type-options", "nosniff") + w.Header().Set("x-xss-protection", "1; mode=block") + } + next.ServeHTTP(w, r) }) } @@ -64,24 +68,15 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil } -func (s *Server) registerAPIEndpoints(router *chi.Mux) error { +func (s *Server) registerDoHEndpoints(router *chi.Mux) { const pathDohQuery = "/dns-query" - openAPIImpl, err := s.createOpenAPIInterfaceImpl() - if err != nil { - return err - } - - api.RegisterOpenAPIEndpoints(router, openAPIImpl) - router.Get(pathDohQuery, s.dohGetRequestHandler) router.Get(pathDohQuery+"/", s.dohGetRequestHandler) router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler) router.Post(pathDohQuery, s.dohPostRequestHandler) router.Post(pathDohQuery+"/", s.dohPostRequestHandler) router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler) - - return nil } func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { @@ -177,27 +172,15 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPSRouter(cfg *config.Config) *chi.Mux { +func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { router := chi.NewRouter() configureSecureHeaderHandler(router) - registerHandlers(cfg, router) - - return router -} - -func createHTTPRouter(cfg *config.Config) *chi.Mux { - router := chi.NewRouter() - - registerHandlers(cfg, router) - - return router -} - -func registerHandlers(cfg *config.Config, router *chi.Mux) { configureCorsHandler(router) + api.RegisterOpenAPIEndpoints(router, openAPIImpl) + configureDebugHandler(router) configureDocsHandler(router) @@ -205,6 +188,10 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) { configureStaticAssetsHandler(router) configureRootHandler(cfg, router) + + metrics.Start(router, cfg.Prometheus) + + return router } func configureDocsHandler(router *chi.Mux) {