From 35b1c16878031b8e1bea2ff0164f761a263cc10e Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 2 Apr 2024 20:11:10 -0400 Subject: [PATCH] refactor(server): deduplicate HTTP server setup with new `httpServer` --- server/http.go | 48 ++++++++++++++++++++++++++ server/server.go | 89 ++++++++++++++++++++---------------------------- 2 files changed, 85 insertions(+), 52 deletions(-) create mode 100644 server/http.go diff --git a/server/http.go b/server/http.go new file mode 100644 index 00000000..78f7fe0d --- /dev/null +++ b/server/http.go @@ -0,0 +1,48 @@ +package server + +import ( + "context" + "net" + "net/http" + "time" +) + +type httpServer struct { + inner http.Server + + name string +} + +func newHTTPServer(name string, handler http.Handler) *httpServer { + const ( + readHeaderTimeout = 20 * time.Second + readTimeout = 20 * time.Second + writeTimeout = 20 * time.Second + ) + + return &httpServer{ + inner: http.Server{ + ReadTimeout: readTimeout, + ReadHeaderTimeout: readHeaderTimeout, + WriteTimeout: writeTimeout, + + Handler: handler, + }, + + name: name, + } +} + +func (s *httpServer) String() string { + return s.name +} + +func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { + go func() { + <-ctx.Done() + + s.inner.Close() + }() + + return s.inner.Serve(l) +} diff --git a/server/server.go b/server/server.go index 8765af83..844734a7 100644 --- a/server/server.go +++ b/server/server.go @@ -27,6 +27,7 @@ import ( "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/resolver" + "github.com/0xERR0R/blocky/util" "github.com/google/uuid" "github.com/hashicorp/go-multierror" @@ -44,12 +45,11 @@ const ( // Server controls the endpoints for DNS and HTTP type Server struct { - dnsServers []*dns.Server - httpListeners []net.Listener - httpsListeners []net.Listener - queryResolver resolver.ChainedResolver - cfg *config.Config - httpMux *chi.Mux + dnsServers []*dns.Server + queryResolver resolver.ChainedResolver + cfg *config.Config + + servers map[net.Listener]*httpServer } func logger() *logrus.Entry { @@ -116,6 +116,8 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config +// +//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { var tlsCfg *tls.Config @@ -157,11 +159,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err } server = &Server{ - dnsServers: dnsServers, - queryResolver: queryResolver, - cfg: cfg, - httpListeners: httpListeners, - httpsListeners: httpsListeners, + dnsServers: dnsServers, + queryResolver: queryResolver, + cfg: cfg, + + servers: make(map[net.Listener]*httpServer), } server.printConfiguration() @@ -173,8 +175,24 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, err } - server.httpMux = createHTTPRouter(cfg, openAPIImpl) - server.registerDoHEndpoints(server.httpMux) + httpRouter := createHTTPRouter(cfg, openAPIImpl) + server.registerDoHEndpoints(httpRouter) + + if len(cfg.Ports.HTTP) != 0 { + srv := newHTTPServer("http", httpRouter) + + for _, l := range httpListeners { + server.servers[l] = srv + } + } + + if len(cfg.Ports.HTTPS) != 0 { + srv := newHTTPServer("https", httpRouter) + + for _, l := range httpsListeners { + server.servers[l] = srv + } + } return server, err } @@ -480,12 +498,6 @@ func toMB(b uint64) uint64 { return b / bytesInKB / bytesInKB } -const ( - readHeaderTimeout = 20 * time.Second - readTimeout = 20 * time.Second - writeTimeout = 20 * time.Second -) - // Start starts the server func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Info("Starting server") @@ -500,42 +512,15 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { }() } - for i, listener := range s.httpListeners { - listener := listener - address := s.cfg.Ports.HTTP[i] + for listener, srv := range s.servers { + listener, srv := listener, srv go func() { - logger().Infof("http server is up and running on addr/port %s", address) + logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr()) - srv := &http.Server{ - ReadTimeout: readTimeout, - ReadHeaderTimeout: readHeaderTimeout, - WriteTimeout: writeTimeout, - Handler: s.httpMux, - } - - if err := srv.Serve(listener); err != nil { - errCh <- fmt.Errorf("start http listener failed: %w", err) - } - }() - } - - for i, listener := range s.httpsListeners { - listener := listener - address := s.cfg.Ports.HTTPS[i] - - go func() { - logger().Infof("https server is up and running on addr/port %s", address) - - server := http.Server{ - Handler: s.httpMux, - ReadTimeout: readTimeout, - ReadHeaderTimeout: readHeaderTimeout, - WriteTimeout: writeTimeout, - } - - if err := server.Serve(listener); err != nil { - errCh <- fmt.Errorf("start https listener failed: %w", err) + err := srv.Serve(ctx, listener) + if err != nil { + errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err) } }() }