refactor(server): deduplicate HTTP server setup with new `httpServer`

This commit is contained in:
ThinkChaos 2024-04-02 20:11:10 -04:00
parent 17b2a94a64
commit 35b1c16878
No known key found for this signature in database
2 changed files with 85 additions and 52 deletions

48
server/http.go Normal file
View File

@ -0,0 +1,48 @@
package server
import (
"context"
"net"
"net/http"
"time"
)
type httpServer struct {
inner http.Server
name string
}
func newHTTPServer(name string, handler http.Handler) *httpServer {
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)
return &httpServer{
inner: http.Server{
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
Handler: handler,
},
name: name,
}
}
func (s *httpServer) String() string {
return s.name
}
func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
go func() {
<-ctx.Done()
s.inner.Close()
}()
return s.inner.Serve(l)
}

View File

@ -27,6 +27,7 @@ import (
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/util"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
@ -44,12 +45,11 @@ const (
// Server controls the endpoints for DNS and HTTP
type Server struct {
dnsServers []*dns.Server
httpListeners []net.Listener
httpsListeners []net.Listener
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
dnsServers []*dns.Server
queryResolver resolver.ChainedResolver
cfg *config.Config
servers map[net.Listener]*httpServer
}
func logger() *logrus.Entry {
@ -116,6 +116,8 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
}
// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var tlsCfg *tls.Config
@ -157,11 +159,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
}
server = &Server{
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,
servers: make(map[net.Listener]*httpServer),
}
server.printConfiguration()
@ -173,8 +175,24 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, err
}
server.httpMux = createHTTPRouter(cfg, openAPIImpl)
server.registerDoHEndpoints(server.httpMux)
httpRouter := createHTTPRouter(cfg, openAPIImpl)
server.registerDoHEndpoints(httpRouter)
if len(cfg.Ports.HTTP) != 0 {
srv := newHTTPServer("http", httpRouter)
for _, l := range httpListeners {
server.servers[l] = srv
}
}
if len(cfg.Ports.HTTPS) != 0 {
srv := newHTTPServer("https", httpRouter)
for _, l := range httpsListeners {
server.servers[l] = srv
}
}
return server, err
}
@ -480,12 +498,6 @@ func toMB(b uint64) uint64 {
return b / bytesInKB / bytesInKB
}
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)
// Start starts the server
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
logger().Info("Starting server")
@ -500,42 +512,15 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
}()
}
for i, listener := range s.httpListeners {
listener := listener
address := s.cfg.Ports.HTTP[i]
for listener, srv := range s.servers {
listener, srv := listener, srv
go func() {
logger().Infof("http server is up and running on addr/port %s", address)
logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr())
srv := &http.Server{
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
Handler: s.httpMux,
}
if err := srv.Serve(listener); err != nil {
errCh <- fmt.Errorf("start http listener failed: %w", err)
}
}()
}
for i, listener := range s.httpsListeners {
listener := listener
address := s.cfg.Ports.HTTPS[i]
go func() {
logger().Infof("https server is up and running on addr/port %s", address)
server := http.Server{
Handler: s.httpMux,
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
}
if err := server.Serve(listener); err != nil {
errCh <- fmt.Errorf("start https listener failed: %w", err)
err := srv.Serve(ctx, listener)
if err != nil {
errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err)
}
}()
}