mirror of https://github.com/0xERR0R/blocky.git
refactor(server): deduplicate HTTP server setup with new `httpServer`
This commit is contained in:
parent
17b2a94a64
commit
35b1c16878
|
@ -0,0 +1,48 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type httpServer struct {
|
||||
inner http.Server
|
||||
|
||||
name string
|
||||
}
|
||||
|
||||
func newHTTPServer(name string, handler http.Handler) *httpServer {
|
||||
const (
|
||||
readHeaderTimeout = 20 * time.Second
|
||||
readTimeout = 20 * time.Second
|
||||
writeTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
return &httpServer{
|
||||
inner: http.Server{
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
|
||||
Handler: handler,
|
||||
},
|
||||
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) String() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
s.inner.Close()
|
||||
}()
|
||||
|
||||
return s.inner.Serve(l)
|
||||
}
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/redis"
|
||||
"github.com/0xERR0R/blocky/resolver"
|
||||
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
@ -44,12 +45,11 @@ const (
|
|||
|
||||
// Server controls the endpoints for DNS and HTTP
|
||||
type Server struct {
|
||||
dnsServers []*dns.Server
|
||||
httpListeners []net.Listener
|
||||
httpsListeners []net.Listener
|
||||
queryResolver resolver.ChainedResolver
|
||||
cfg *config.Config
|
||||
httpMux *chi.Mux
|
||||
dnsServers []*dns.Server
|
||||
queryResolver resolver.ChainedResolver
|
||||
cfg *config.Config
|
||||
|
||||
servers map[net.Listener]*httpServer
|
||||
}
|
||||
|
||||
func logger() *logrus.Entry {
|
||||
|
@ -116,6 +116,8 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
|
|||
}
|
||||
|
||||
// NewServer creates new server instance with passed config
|
||||
//
|
||||
//nolint:funlen
|
||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||
var tlsCfg *tls.Config
|
||||
|
||||
|
@ -157,11 +159,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
}
|
||||
|
||||
server = &Server{
|
||||
dnsServers: dnsServers,
|
||||
queryResolver: queryResolver,
|
||||
cfg: cfg,
|
||||
httpListeners: httpListeners,
|
||||
httpsListeners: httpsListeners,
|
||||
dnsServers: dnsServers,
|
||||
queryResolver: queryResolver,
|
||||
cfg: cfg,
|
||||
|
||||
servers: make(map[net.Listener]*httpServer),
|
||||
}
|
||||
|
||||
server.printConfiguration()
|
||||
|
@ -173,8 +175,24 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
return nil, err
|
||||
}
|
||||
|
||||
server.httpMux = createHTTPRouter(cfg, openAPIImpl)
|
||||
server.registerDoHEndpoints(server.httpMux)
|
||||
httpRouter := createHTTPRouter(cfg, openAPIImpl)
|
||||
server.registerDoHEndpoints(httpRouter)
|
||||
|
||||
if len(cfg.Ports.HTTP) != 0 {
|
||||
srv := newHTTPServer("http", httpRouter)
|
||||
|
||||
for _, l := range httpListeners {
|
||||
server.servers[l] = srv
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.Ports.HTTPS) != 0 {
|
||||
srv := newHTTPServer("https", httpRouter)
|
||||
|
||||
for _, l := range httpsListeners {
|
||||
server.servers[l] = srv
|
||||
}
|
||||
}
|
||||
|
||||
return server, err
|
||||
}
|
||||
|
@ -480,12 +498,6 @@ func toMB(b uint64) uint64 {
|
|||
return b / bytesInKB / bytesInKB
|
||||
}
|
||||
|
||||
const (
|
||||
readHeaderTimeout = 20 * time.Second
|
||||
readTimeout = 20 * time.Second
|
||||
writeTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
||||
logger().Info("Starting server")
|
||||
|
@ -500,42 +512,15 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
|||
}()
|
||||
}
|
||||
|
||||
for i, listener := range s.httpListeners {
|
||||
listener := listener
|
||||
address := s.cfg.Ports.HTTP[i]
|
||||
for listener, srv := range s.servers {
|
||||
listener, srv := listener, srv
|
||||
|
||||
go func() {
|
||||
logger().Infof("http server is up and running on addr/port %s", address)
|
||||
logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr())
|
||||
|
||||
srv := &http.Server{
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
Handler: s.httpMux,
|
||||
}
|
||||
|
||||
if err := srv.Serve(listener); err != nil {
|
||||
errCh <- fmt.Errorf("start http listener failed: %w", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for i, listener := range s.httpsListeners {
|
||||
listener := listener
|
||||
address := s.cfg.Ports.HTTPS[i]
|
||||
|
||||
go func() {
|
||||
logger().Infof("https server is up and running on addr/port %s", address)
|
||||
|
||||
server := http.Server{
|
||||
Handler: s.httpMux,
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
}
|
||||
|
||||
if err := server.Serve(listener); err != nil {
|
||||
errCh <- fmt.Errorf("start https listener failed: %w", err)
|
||||
err := srv.Serve(ctx, listener)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue