diff --git a/server/server.go b/server/server.go index c8cf05ed..8765af83 100644 --- a/server/server.go +++ b/server/server.go @@ -50,7 +50,6 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config httpMux *chi.Mux - tlsCfg *tls.Config } func logger() *logrus.Entry { @@ -117,8 +116,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config -// -//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { var tlsCfg *tls.Config @@ -134,7 +131,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpListeners, httpsListeners, err := createHTTPListeners(cfg) + httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg) if err != nil { return nil, err } @@ -165,7 +162,6 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err cfg: cfg, httpListeners: httpListeners, httpsListeners: httpsListeners, - tlsCfg: tlsCfg, } server.printConfiguration() @@ -211,13 +207,15 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error return dnsServers, err.ErrorOrNil() } -func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) { - httpListeners, err = newListeners("http", cfg.Ports.HTTP) +func createHTTPListeners( + cfg *config.Config, tlsCfg *tls.Config, +) (httpListeners, httpsListeners []net.Listener, err error) { + httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP) if err != nil { return nil, nil, err } - httpsListeners, err = newListeners("https", cfg.Ports.HTTPS) + httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg) if err != nil { return nil, nil, err } @@ -225,7 +223,7 @@ func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []ne return httpListeners, httpsListeners, nil } -func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { +func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { listeners := make([]net.Listener, 0, len(addresses)) for _, address := range addresses { @@ -240,6 +238,19 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, return listeners, nil } +func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) { + listeners, err := newTCPListeners(proto, addresses) + if err != nil { + return nil, err + } + + for i, inner := range listeners { + listeners[i] = tls.NewListener(inner, tlsCfg) + } + + return listeners, nil +} + func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) { return &dns.Server{ Addr: address, @@ -521,10 +532,9 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { ReadTimeout: readTimeout, ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, - TLSConfig: s.tlsCfg, } - if err := server.ServeTLS(listener, "", ""); err != nil { + if err := server.Serve(listener); err != nil { errCh <- fmt.Errorf("start https listener failed: %w", err) } }() diff --git a/server/server_test.go b/server/server_test.go index bb2c53d5..e88d9a14 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/base64" - "fmt" "io" "net" "net/http" @@ -741,11 +740,12 @@ var _ = Describe("Running DNS server", func() { cfg.KeyFile = "" cfg.CertFile = "" cfg.Ports = config.Ports{ - HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)}, + HTTPS: []string{":0"}, } - sut, err := NewServer(ctx, &cfg) + + sut, err := newTLSConfig(&cfg) Expect(err).Should(Succeed()) - Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty()) + Expect(sut.Certificates).ShouldNot(BeEmpty()) }) }) })