refactor(server): deduplicate `tls.Config` setup

This commit is contained in:
ThinkChaos 2024-03-29 19:27:26 -04:00
parent c389a4a0f4
commit c6e3de4ae0
No known key found for this signature in database
2 changed files with 33 additions and 25 deletions

View File

@ -50,7 +50,7 @@ type Server struct {
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
cert tls.Certificate
tlsCfg *tls.Config
}
func logger() *logrus.Entry {
@ -98,20 +98,38 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
return
}
func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
var cert tls.Certificate
cert, err := retrieveCertificate(cfg)
if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err)
}
// #nosec G402 // See TLSVersion.validate
res := &tls.Config{
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{cert},
}
return res, nil
}
// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var cert tls.Certificate
var tlsCfg *tls.Config
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
cert, err = retrieveCertificate(cfg)
tlsCfg, err = newTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err)
return nil, err
}
}
dnsServers, err := createServers(cfg, cert)
dnsServers, err := createServers(cfg, tlsCfg)
if err != nil {
return nil, fmt.Errorf("server creation failed: %w", err)
}
@ -147,7 +165,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
cert: cert,
tlsCfg: tlsCfg,
}
server.printConfiguration()
@ -165,7 +183,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return server, err
}
func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) {
func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
var dnsServers []*dns.Server
var err *multierror.Error
@ -187,7 +205,7 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
addServers(createUDPServer, cfg.Ports.DNS),
addServers(createTCPServer, cfg.Ports.DNS),
addServers(func(address string) (*dns.Server, error) {
return createTLSServer(cfg, address, cert)
return createTLSServer(address, tlsCfg)
}, cfg.Ports.TLS))
return dnsServers, err.ErrorOrNil()
@ -222,17 +240,12 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener,
return listeners, nil
}
func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) {
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
return &dns.Server{
Addr: address,
Net: "tcp-tls",
//nolint:gosec
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
},
Handler: dns.NewServeMux(),
Addr: address,
Net: "tcp-tls",
TLSConfig: tlsCfg,
Handler: dns.NewServeMux(),
NotifyStartedFunc: func() {
logger().Infof("TLS server is up and running on address %s", address)
},
@ -508,12 +521,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
//nolint:gosec
TLSConfig: &tls.Config{
MinVersion: uint16(s.cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{s.cert},
},
TLSConfig: s.tlsCfg,
}
if err := server.ServeTLS(listener, "", ""); err != nil {

View File

@ -745,7 +745,7 @@ var _ = Describe("Running DNS server", func() {
}
sut, err := NewServer(ctx, &cfg)
Expect(err).Should(Succeed())
Expect(sut.cert.Certificate).ShouldNot(BeNil())
Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty())
})
})
})