refactor(server): setup TLS listeners manually to remove `ServeTLS` use

This commit is contained in:
ThinkChaos 2024-04-02 20:10:30 -04:00
parent c6e3de4ae0
commit 17b2a94a64
No known key found for this signature in database
2 changed files with 25 additions and 15 deletions

View File

@ -50,7 +50,6 @@ type Server struct {
queryResolver resolver.ChainedResolver queryResolver resolver.ChainedResolver
cfg *config.Config cfg *config.Config
httpMux *chi.Mux httpMux *chi.Mux
tlsCfg *tls.Config
} }
func logger() *logrus.Entry { func logger() *logrus.Entry {
@ -117,8 +116,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
} }
// NewServer creates new server instance with passed config // NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var tlsCfg *tls.Config var tlsCfg *tls.Config
@ -134,7 +131,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, fmt.Errorf("server creation failed: %w", err) return nil, fmt.Errorf("server creation failed: %w", err)
} }
httpListeners, httpsListeners, err := createHTTPListeners(cfg) httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -165,7 +162,6 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
cfg: cfg, cfg: cfg,
httpListeners: httpListeners, httpListeners: httpListeners,
httpsListeners: httpsListeners, httpsListeners: httpsListeners,
tlsCfg: tlsCfg,
} }
server.printConfiguration() server.printConfiguration()
@ -211,13 +207,15 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
return dnsServers, err.ErrorOrNil() return dnsServers, err.ErrorOrNil()
} }
func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) { func createHTTPListeners(
httpListeners, err = newListeners("http", cfg.Ports.HTTP) cfg *config.Config, tlsCfg *tls.Config,
) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
httpsListeners, err = newListeners("https", cfg.Ports.HTTPS) httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -225,7 +223,7 @@ func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []ne
return httpListeners, httpsListeners, nil return httpListeners, httpsListeners, nil
} }
func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
listeners := make([]net.Listener, 0, len(addresses)) listeners := make([]net.Listener, 0, len(addresses))
for _, address := range addresses { for _, address := range addresses {
@ -240,6 +238,19 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener,
return listeners, nil return listeners, nil
} }
func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) {
listeners, err := newTCPListeners(proto, addresses)
if err != nil {
return nil, err
}
for i, inner := range listeners {
listeners[i] = tls.NewListener(inner, tlsCfg)
}
return listeners, nil
}
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) { func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
return &dns.Server{ return &dns.Server{
Addr: address, Addr: address,
@ -521,10 +532,9 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
ReadTimeout: readTimeout, ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout, ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout, WriteTimeout: writeTimeout,
TLSConfig: s.tlsCfg,
} }
if err := server.ServeTLS(listener, "", ""); err != nil { if err := server.Serve(listener); err != nil {
errCh <- fmt.Errorf("start https listener failed: %w", err) errCh <- fmt.Errorf("start https listener failed: %w", err)
} }
}() }()

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/base64" "encoding/base64"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -741,11 +740,12 @@ var _ = Describe("Running DNS server", func() {
cfg.KeyFile = "" cfg.KeyFile = ""
cfg.CertFile = "" cfg.CertFile = ""
cfg.Ports = config.Ports{ cfg.Ports = config.Ports{
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)}, HTTPS: []string{":0"},
} }
sut, err := NewServer(ctx, &cfg)
sut, err := newTLSConfig(&cfg)
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty()) Expect(sut.Certificates).ShouldNot(BeEmpty())
}) })
}) })
}) })