refactor(server): setup TLS listeners manually to remove `ServeTLS` use

This commit is contained in:
ThinkChaos 2024-04-02 20:10:30 -04:00
parent c6e3de4ae0
commit 17b2a94a64
No known key found for this signature in database
2 changed files with 25 additions and 15 deletions

View File

@ -50,7 +50,6 @@ type Server struct {
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
tlsCfg *tls.Config
}
func logger() *logrus.Entry {
@ -117,8 +116,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
}
// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var tlsCfg *tls.Config
@ -134,7 +131,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, fmt.Errorf("server creation failed: %w", err)
}
httpListeners, httpsListeners, err := createHTTPListeners(cfg)
httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg)
if err != nil {
return nil, err
}
@ -165,7 +162,6 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
tlsCfg: tlsCfg,
}
server.printConfiguration()
@ -211,13 +207,15 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
return dnsServers, err.ErrorOrNil()
}
func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newListeners("http", cfg.Ports.HTTP)
func createHTTPListeners(
cfg *config.Config, tlsCfg *tls.Config,
) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP)
if err != nil {
return nil, nil, err
}
httpsListeners, err = newListeners("https", cfg.Ports.HTTPS)
httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg)
if err != nil {
return nil, nil, err
}
@ -225,7 +223,7 @@ func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []ne
return httpListeners, httpsListeners, nil
}
func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
listeners := make([]net.Listener, 0, len(addresses))
for _, address := range addresses {
@ -240,6 +238,19 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener,
return listeners, nil
}
func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) {
listeners, err := newTCPListeners(proto, addresses)
if err != nil {
return nil, err
}
for i, inner := range listeners {
listeners[i] = tls.NewListener(inner, tlsCfg)
}
return listeners, nil
}
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
return &dns.Server{
Addr: address,
@ -521,10 +532,9 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
TLSConfig: s.tlsCfg,
}
if err := server.ServeTLS(listener, "", ""); err != nil {
if err := server.Serve(listener); err != nil {
errCh <- fmt.Errorf("start https listener failed: %w", err)
}
}()

View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
@ -741,11 +740,12 @@ var _ = Describe("Running DNS server", func() {
cfg.KeyFile = ""
cfg.CertFile = ""
cfg.Ports = config.Ports{
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)},
HTTPS: []string{":0"},
}
sut, err := NewServer(ctx, &cfg)
sut, err := newTLSConfig(&cfg)
Expect(err).Should(Succeed())
Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty())
Expect(sut.Certificates).ShouldNot(BeEmpty())
})
})
})